diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..c42767d --- /dev/null +++ b/.coveragerc @@ -0,0 +1,11 @@ +[run] +omit = + src/tests/backend/* + */test_*.py + */*_test.py + */__init__.py + +[report] +exclude_lines = + pragma: no cover + if __name__ == "__main__": diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 34a2f24..2f0f94d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,11 +1,13 @@ name: Test Workflow with Coverage - Code-Gen on: + workflow_dispatch: push: branches: - main - dev - demo + - hotfix pull_request: types: - opened @@ -16,52 +18,11 @@ on: - main - dev - demo + - hotfix jobs: -# frontend_tests: -# runs-on: ubuntu-latest - -# steps: -# - name: Checkout code -# uses: actions/checkout@v3 - -# - name: Set up Node.js -# uses: actions/setup-node@v3 -# with: -# node-version: '20' - -# - name: Check if Frontend Test Files Exist -# id: check_frontend_tests -# run: | -# if [ -z "$(find src/tests/frontend -type f -name '*.test.js' -o -name '*.test.ts' -o -name '*.test.tsx')" ]; then -# echo "No frontend test files found, skipping frontend tests." -# echo "skip_frontend_tests=true" >> $GITHUB_ENV -# else -# echo "Frontend test files found, running tests." -# echo "skip_frontend_tests=false" >> $GITHUB_ENV -# fi - -# - name: Install Frontend Dependencies -# if: env.skip_frontend_tests == 'false' -# run: | -# cd src/frontend -# npm install - -# - name: Run Frontend Tests with Coverage -# if: env.skip_frontend_tests == 'false' -# run: | -# cd src/tests/frontend -# npm run test -- --coverage - -# - name: Skip Frontend Tests -# if: env.skip_frontend_tests == 'true' -# run: | -# echo "Skipping frontend tests because no test files were found." - backend_tests: runs-on: ubuntu-latest - - steps: - name: Checkout code uses: actions/checkout@v3 @@ -71,36 +32,22 @@ jobs: with: python-version: '3.11' - - name: Install Backend Dependencies + - name: Install Dependencies run: | python -m pip install --upgrade pip pip install -r src/backend/requirements.txt pip install -r src/frontend/requirements.txt - pip install pytest-cov - pip install pytest-asyncio + pip install pytest-cov pytest-asyncio + - name: Set PYTHONPATH run: echo "PYTHONPATH=$PWD/src/backend" >> $GITHUB_ENV - - name: Check if Backend Test Files Exist - id: check_backend_tests - run: | - if [ -z "$(find src/tests/backend -type f -name '*_test.py')" ]; then - echo "No backend test files found, skipping backend tests." - echo "skip_backend_tests=true" >> $GITHUB_ENV - else - echo "Backend test files found, running tests." - echo "skip_backend_tests=false" >> $GITHUB_ENV - fi - - name: Run Backend Tests with Coverage - if: env.skip_backend_tests == 'false' run: | cd src - pytest --cov=. --cov-report=term-missing --cov-report=xml - - - - - name: Skip Backend Tests - if: env.skip_backend_tests == 'true' - run: | - echo "Skipping backend tests because no test files were found." + # only measure coverage for src/backend, omit tests via .coveragerc + pytest \ + --cov=backend \ + --cov-report=term-missing \ + --cov-report=xml \ + --cov-config=../.coveragerc diff --git a/src/backend/common/storage/blob_base.py b/src/backend/common/storage/blob_base.py index 4495584..7375299 100644 --- a/src/backend/common/storage/blob_base.py +++ b/src/backend/common/storage/blob_base.py @@ -25,7 +25,7 @@ async def upload_file( Returns: Dict containing upload details (url, size, etc.) """ - pass + pass # pragma: no cover @abstractmethod async def get_file(self, blob_path: str) -> BinaryIO: @@ -38,7 +38,7 @@ async def get_file(self, blob_path: str) -> BinaryIO: Returns: File content as a binary stream """ - pass + pass # pragma: no cover @abstractmethod async def delete_file(self, blob_path: str) -> bool: @@ -51,7 +51,7 @@ async def delete_file(self, blob_path: str) -> bool: Returns: True if deletion was successful """ - pass + pass # pragma: no cover @abstractmethod async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]]: @@ -64,4 +64,4 @@ async def list_files(self, prefix: Optional[str] = None) -> list[Dict[str, Any]] Returns: List of blob details """ - pass + pass # pragma: no cover diff --git a/src/backend/sql_agents/convert_script.py b/src/backend/sql_agents/convert_script.py index 3686886..60e6378 100644 --- a/src/backend/sql_agents/convert_script.py +++ b/src/backend/sql_agents/convert_script.py @@ -34,7 +34,7 @@ logger.setLevel(logging.DEBUG) -async def convert_script( +async def convert_script( # pragma: no cover source_script, file: FileRecord, batch_service: BatchService, diff --git a/src/tests/backend/api/auth/__init__.py b/src/tests/backend/api/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/tests/backend/api/auth/auth_utils_test.py b/src/tests/backend/api/auth/auth_utils_test.py new file mode 100644 index 0000000..a9a4f95 --- /dev/null +++ b/src/tests/backend/api/auth/auth_utils_test.py @@ -0,0 +1,83 @@ +import base64 +import json +from unittest.mock import MagicMock + +from api.auth.auth_utils import UserDetails, get_authenticated_user, get_tenant_id + +from fastapi import HTTPException, Request + +import pytest + + +def test_get_tenant_id_valid(): + payload = {"tid": "tenant123"} + encoded = base64.b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8") + + result = get_tenant_id(encoded) + assert result == "tenant123" + + +def test_get_tenant_id_invalid(): + invalid_b64 = "invalid_base64_string" + result = get_tenant_id(invalid_b64) + assert result == "" + + +def test_user_details_initialization_with_tenant(): + payload = {"tid": "tenant456"} + encoded = base64.b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8") + + user_data = { + "user_principal_id": "user1", + "user_name": "John Doe", + "auth_provider": "aad", + "auth_token": "fake_token", + "client_principal_b64": encoded, + } + + user = UserDetails(user_data) + assert user.user_principal_id == "user1" + assert user.user_name == "John Doe" + assert user.tenant_id == "tenant456" + + +def test_user_details_initialization_without_tenant(): + user_data = { + "user_principal_id": "user2", + "user_name": "Jane Doe", + "auth_provider": "aad", + "auth_token": "fake_token", + "client_principal_b64": "your_base_64_encoded_token", + } + + user = UserDetails(user_data) + assert user.tenant_id is None + + +def test_get_authenticated_user_valid(): + headers = { + "x-ms-client-principal-id": "user3", + } + + mock_request = MagicMock(spec=Request) + mock_request.headers = headers + + user = get_authenticated_user(mock_request) + assert isinstance(user, UserDetails) + assert user.user_principal_id == "user3" + + +def test_get_authenticated_user_raises_http_exception(monkeypatch): + # Mocking a development environment with no user principal in sample_user + sample_user_mock = {"some-header": "some-value"} + + monkeypatch.setattr("api.auth.auth_utils.sample_user", sample_user_mock) + + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + + with pytest.raises(HTTPException) as exc_info: + get_authenticated_user(mock_request) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "User not authenticated" diff --git a/src/tests/backend/api/status_updates_test.py b/src/tests/backend/api/status_updates_test.py new file mode 100644 index 0000000..027465a --- /dev/null +++ b/src/tests/backend/api/status_updates_test.py @@ -0,0 +1,139 @@ +import asyncio +import uuid +from unittest.mock import AsyncMock, patch + +from api import status_updates + +from common.models.api import AgentType, FileProcessUpdate, FileResult, ProcessStatus + +import pytest + + +@pytest.fixture +def file_process_update(): + return FileProcessUpdate( + batch_id=uuid.uuid4(), + file_id=uuid.uuid4(), + process_status=ProcessStatus.IN_PROGRESS, + agent_type=AgentType.MIGRATOR, + agent_message="Processing in progress", + file_result=FileResult.INFO + ) + + +@pytest.fixture +def mock_websocket(): + return AsyncMock() + + +@pytest.mark.asyncio +async def test_send_status_update_async_success(file_process_update): + mock_websocket = AsyncMock() + status_updates.app_connection_manager.add_connection(file_process_update.batch_id, mock_websocket) + + with patch("api.status_updates.json.dumps", return_value='{"batch_id": "test_batch", "status": "Processing", "progress": 50}'): + await status_updates.send_status_update_async(file_process_update) + + mock_websocket.send_text.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_status_update_async_no_connection(file_process_update): + # No connection added + with patch("api.status_updates.logger") as mock_logger: + await status_updates.send_status_update_async(file_process_update) + mock_logger.warning.assert_called_once_with( + "No connection found for batch ID: %s", file_process_update.batch_id + ) + + +def test_send_status_update_success(file_process_update): + mock_websocket = AsyncMock() + loop = asyncio.new_event_loop() + + with patch("api.status_updates.asyncio.get_event_loop", return_value=loop): + with patch("api.status_updates.asyncio.run_coroutine_threadsafe") as mock_run: + status_updates.app_connection_manager.add_connection(str(file_process_update.batch_id), mock_websocket) + + with patch("api.status_updates.json.dumps", return_value='{}'): + status_updates.send_status_update(file_process_update) + + mock_run.assert_called_once() + + +def test_send_status_update_no_connection(file_process_update): + with patch("api.status_updates.logger") as mock_logger: + status_updates.send_status_update(file_process_update) + + mock_logger.warning.assert_called() + args, kwargs = mock_logger.warning.call_args + assert "No connection found for batch ID" in args[0] + + +@pytest.mark.asyncio +async def test_close_connection_success(file_process_update, mock_websocket): + status_updates.app_connection_manager.add_connection(file_process_update.batch_id, mock_websocket) + loop = asyncio.new_event_loop() + + with patch("api.status_updates.asyncio.get_event_loop", return_value=loop): + with patch("api.status_updates.asyncio.run_coroutine_threadsafe") as mock_run: + with patch("api.status_updates.logger") as mock_logger: + await status_updates.close_connection(file_process_update.batch_id) + + mock_run.assert_called_once() + mock_logger.info.assert_any_call("Connection closed for batch ID: %s", file_process_update.batch_id) + mock_logger.info.assert_any_call("Connection removed for batch ID: %s", file_process_update.batch_id) + + +@pytest.mark.asyncio +async def test_close_connection_no_connection(file_process_update): + with patch("api.status_updates.logger") as mock_logger: + await status_updates.close_connection(file_process_update.batch_id) + + mock_logger.warning.assert_called_once_with( + "No connection found for batch ID: %s", file_process_update.batch_id + ) + mock_logger.info.assert_called_once_with( + "Connection removed for batch ID: %s", file_process_update.batch_id + ) + + +# Test the connection manager directly +def test_connection_manager_methods(): + # Get the actual connection manager instance + manager = status_updates.app_connection_manager + + # Test the get_connection method + batch_id = uuid.uuid4() + assert manager.get_connection(batch_id) is None + + # Test add_connection method + mock_websocket = AsyncMock() + manager.add_connection(batch_id, mock_websocket) + assert manager.get_connection(batch_id) == mock_websocket + + # Test overwriting an existing connection + new_mock_websocket = AsyncMock() + manager.add_connection(batch_id, new_mock_websocket) + assert manager.get_connection(batch_id) == new_mock_websocket + + # Test remove_connection method + manager.remove_connection(batch_id) + assert manager.get_connection(batch_id) is None + + # Test removing a non-existent connection (should not raise an error) + manager.remove_connection(uuid.uuid4()) + + +def test_send_status_update_exception(file_process_update): + mock_websocket = AsyncMock() + status_updates.app_connection_manager.add_connection(str(file_process_update.batch_id), mock_websocket) + + with patch("api.status_updates.asyncio.get_event_loop") as mock_loop: + mock_loop.return_value = asyncio.new_event_loop() + with patch("api.status_updates.json.dumps", return_value='{}'): + with patch("api.status_updates.asyncio.run_coroutine_threadsafe", side_effect=Exception("send error")): + with patch("api.status_updates.logger") as mock_logger: + status_updates.send_status_update(file_process_update) + mock_logger.error.assert_called_once() + assert "Failed to send message" in mock_logger.error.call_args[0][0] diff --git a/src/tests/backend/common/database/cosmosdb_test.py b/src/tests/backend/common/database/cosmosdb_test.py index df53fde..3b9c128 100644 --- a/src/tests/backend/common/database/cosmosdb_test.py +++ b/src/tests/backend/common/database/cosmosdb_test.py @@ -1,31 +1,24 @@ -import os -import sys -# Add backend directory to sys.path -sys.path.insert( - 0, - os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..", "backend")), -) -from datetime import datetime, timezone # noqa: E402 -from unittest import mock # noqa: E402 -from unittest.mock import AsyncMock # noqa: E402 -from uuid import uuid4 # noqa: E402 +from datetime import datetime, timezone +from unittest import mock +from unittest.mock import AsyncMock +from uuid import uuid4 -from azure.cosmos.aio import CosmosClient # noqa: E402 -from azure.cosmos.exceptions import CosmosResourceExistsError # noqa: E402 +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosResourceExistsError -from common.database.cosmosdb import ( # noqa: E402 +from common.database.cosmosdb import ( CosmosDBClient, ) -from common.models.api import ( # noqa: E402 +from common.models.api import ( AgentType, AuthorRole, BatchRecord, FileRecord, LogType, - ProcessStatus, -) # noqa: E402 + ProcessStatus +) -import pytest # noqa: E402 +import pytest # Mocked data for the test endpoint = "https://fake.cosmosdb.azure.com" diff --git a/src/tests/backend/sql_agents/agents/agent_base_test.py b/src/tests/backend/sql_agents/agents/agent_base_test.py new file mode 100644 index 0000000..f2a821c --- /dev/null +++ b/src/tests/backend/sql_agents/agents/agent_base_test.py @@ -0,0 +1,98 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import pytest_asyncio + +from semantic_kernel.functions import KernelArguments + +from sql_agents.agents.agent_base import BaseSQLAgent +from sql_agents.helpers.models import AgentType + + +# Concrete subclass for testing +class DummyResponse: + @classmethod + def model_json_schema(cls): + return {"type": "object"} + + +class DummySQLAgent(BaseSQLAgent): + @property + def response_object(self) -> type: + return DummyResponse + + @property + def deployment_name(self) -> str: + return self.config.model_type.get(self.agent_type) + + +class FakeAgentModel: + def __init__(self): + self.name = "test-agent" + self.description = "test-description" + self.id = "agent-id" + self.instructions = "some instructions" + + +@pytest.fixture +def mock_config(): + mock = MagicMock() + mock.sql_to = "TSQL" + mock.sql_from = "MySQL" + mock.model_type = {AgentType.FIXER: "test-model"} + mock.ai_project_client.agents.create_agent = AsyncMock() + return mock + + +@pytest_asyncio.fixture +async def dummy_agent(mock_config): + return DummySQLAgent(agent_type=AgentType.FIXER, config=mock_config) + + +def test_properties(dummy_agent): + assert dummy_agent.agent_type == AgentType.FIXER + assert dummy_agent.config.sql_to == "TSQL" + assert dummy_agent.num_candidates is None + assert dummy_agent.plugins is None + assert dummy_agent.deployment_name == "test-model" + + +def test_get_kernel_arguments(dummy_agent): + args = dummy_agent.get_kernel_arguments() + assert isinstance(args, KernelArguments) + assert args["target"] == "TSQL" + assert args["source"] == "MySQL" + + +@pytest.mark.asyncio +async def test_setup_file_not_found(dummy_agent): + with patch("sql_agents.agents.agent_base.get_prompt", side_effect=FileNotFoundError): + with pytest.raises(ValueError, match="Prompt file for fixer not found."): + await dummy_agent.setup() + + +@pytest.mark.asyncio +async def test_get_agent_sets_up(dummy_agent): + dummy_agent.agent = None + + async def mock_setup(): + dummy_agent.agent = "mocked_agent" + + with patch.object(dummy_agent, "setup", new=AsyncMock(side_effect=mock_setup)) as mock_setup_fn, \ + patch("sql_agents.agents.agent_base.get_prompt", return_value="prompt content"): + + await dummy_agent.get_agent() + + mock_setup_fn.assert_awaited_once() + assert dummy_agent.agent == "mocked_agent" + + +@pytest.mark.asyncio +async def test_execute_invokes_agent(dummy_agent): + dummy_agent.agent = MagicMock() + dummy_agent.agent.invoke = AsyncMock(return_value={"result": "ok"}) + + result = await dummy_agent.execute("input query") + dummy_agent.agent.invoke.assert_awaited_once_with("input query") + assert result == {"result": "ok"} diff --git a/src/tests/backend/sql_agents/agents/agent_factory_test.py b/src/tests/backend/sql_agents/agents/agent_factory_test.py new file mode 100644 index 0000000..8b50b29 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/agent_factory_test.py @@ -0,0 +1,66 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from sql_agents.agents.agent_base import BaseSQLAgent +from sql_agents.agents.agent_factory import SQLAgentFactory +from sql_agents.helpers.models import AgentType + + +# Mock agent class for registration test +class DummyAgent(BaseSQLAgent): + def __init__(self, **kwargs): + pass + + async def setup(self): + return "dummy-agent" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("agent_type", [ + AgentType.FIXER, + AgentType.MIGRATOR, + AgentType.PICKER, + AgentType.SEMANTIC_VERIFIER, + AgentType.SYNTAX_CHECKER, +]) +async def test_create_agent_success(agent_type): + mock_config = MagicMock() + + # Patch the actual agent class with a mock + mock_agent_class = MagicMock() + mock_agent_instance = MagicMock() + mock_agent_instance.setup = AsyncMock(return_value=f"{agent_type.value}-mock-agent") + mock_agent_class.return_value = mock_agent_instance + + SQLAgentFactory._agent_classes[agent_type] = mock_agent_class + + agent = await SQLAgentFactory.create_agent(agent_type, mock_config) + assert agent == f"{agent_type.value}-mock-agent" + mock_agent_class.assert_called_once() + mock_agent_instance.setup.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_create_agent_invalid_type(): + with pytest.raises(ValueError, match="Unknown agent type: dummy"): + await SQLAgentFactory.create_agent("dummy", MagicMock()) + + +def test_get_agent_class_success(): + for agent_type in SQLAgentFactory._agent_classes: + cls = SQLAgentFactory.get_agent_class(agent_type) + assert cls == SQLAgentFactory._agent_classes[agent_type] + + +def test_get_agent_class_failure(): + with pytest.raises(ValueError, match="Unknown agent type: dummy"): + SQLAgentFactory.get_agent_class("dummy") + + +# def test_register_agent_class(caplog): +# agent_type = "dummy_type" +# SQLAgentFactory.register_agent_class(agent_type, DummyAgent) + +# assert SQLAgentFactory._agent_classes[agent_type] == DummyAgent +# assert any("Registered agent class DummyAgent" in message for message in caplog.text.splitlines()) diff --git a/src/tests/backend/sql_agents/agents/fixer/fixer_agent_test.py b/src/tests/backend/sql_agents/agents/fixer/fixer_agent_test.py new file mode 100644 index 0000000..266ddb0 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/fixer/fixer_agent_test.py @@ -0,0 +1,34 @@ +from unittest.mock import MagicMock + +import pytest + +from sql_agents.agents.fixer.agent import FixerAgent +from sql_agents.agents.fixer.response import FixerResponse +from sql_agents.helpers.models import AgentType + + +@pytest.fixture +def mock_config(): + """Fixture to mock the config for FixerAgent.""" + mock_config = MagicMock() + mock_config.model_type = { + AgentType.FIXER: "fixer_model_name" + } + return mock_config + + +@pytest.fixture +def fixer_agent(mock_config): + """Fixture to create an instance of FixerAgent with a mocked config.""" + agent = FixerAgent(config=mock_config, agent_type=AgentType.FIXER) + return agent + + +def test_response_object(fixer_agent): + """Test the response_object property.""" + assert fixer_agent.response_object == FixerResponse + + +def test_deployment_name(fixer_agent): + """Test the deployment_name property.""" + assert fixer_agent.deployment_name == "fixer_model_name" diff --git a/src/tests/backend/sql_agents/agents/fixer/fixer_response_test.py b/src/tests/backend/sql_agents/agents/fixer/fixer_response_test.py new file mode 100644 index 0000000..a4292d5 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/fixer/fixer_response_test.py @@ -0,0 +1,41 @@ +from pydantic_core import ValidationError + +import pytest + +from sql_agents.agents.fixer.response import FixerResponse + + +def test_fixer_response_creation_all_fields(): + """Test creating FixerResponse with all fields.""" + response = FixerResponse( + thought="Analyzing query structure", + fixed_query="SELECT * FROM users", + summary="Corrected syntax error" + ) + + assert response.thought == "Analyzing query structure" + assert response.fixed_query == "SELECT * FROM users" + assert response.summary == "Corrected syntax error" + + +def test_fixer_response_creation_optional_summary(): + """Test creating FixerResponse without optional summary.""" + response = FixerResponse( + thought="Fix completed", + fixed_query="SELECT name FROM customers", + summary=None + ) + + assert response.thought == "Fix completed" + assert response.fixed_query == "SELECT name FROM customers" + assert response.summary is None + + +def test_fixer_response_invalid_field_types(): + """Test FixerResponse raises error for invalid field types.""" + with pytest.raises(ValidationError, match="Input should be a valid string"): + FixerResponse( + thought=123, # Invalid type + fixed_query=["SELECT * FROM orders"], # Invalid type + summary=456 # Should be str or None + ) diff --git a/src/tests/backend/sql_agents/agents/migrator/migrator_agent_test.py b/src/tests/backend/sql_agents/agents/migrator/migrator_agent_test.py new file mode 100644 index 0000000..4394db9 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/migrator/migrator_agent_test.py @@ -0,0 +1,36 @@ +from unittest.mock import MagicMock + +import pytest + +from sql_agents.agents.migrator.agent import MigratorAgent +from sql_agents.agents.migrator.response import MigratorResponse +from sql_agents.helpers.models import AgentType + + +@pytest.fixture +def mock_config(): + mock = MagicMock() + mock.model_type = { + AgentType.MIGRATOR: "migrator-model-name" + } + return mock + + +@pytest.fixture +def migrator_agent(mock_config): + return MigratorAgent(config=mock_config, agent_type=AgentType.MIGRATOR) + + +def test_response_object(migrator_agent): + """Test that the response_object returns MigratorResponse.""" + assert migrator_agent.response_object is MigratorResponse + + +def test_num_candidates(migrator_agent): + """Test that the num_candidates property returns 3.""" + assert migrator_agent.num_candidates == 3 + + +def test_deployment_name(migrator_agent): + """Test that the correct model name is returned from config.""" + assert migrator_agent.deployment_name == "migrator-model-name" diff --git a/src/tests/backend/sql_agents/agents/migrator/migrator_response_test.py b/src/tests/backend/sql_agents/agents/migrator/migrator_response_test.py new file mode 100644 index 0000000..47872a1 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/migrator/migrator_response_test.py @@ -0,0 +1,34 @@ +from sql_agents.agents.migrator.response import MigratorCandidate, MigratorResponse + + +def test_migrator_candidate_creation(): + """Test that MigratorCandidate can be created successfully.""" + candidate = MigratorCandidate(plan="Use LEFT JOIN", candidate_query="SELECT * FROM table1 LEFT JOIN table2") + assert candidate.plan == "Use LEFT JOIN" + assert candidate.candidate_query == "SELECT * FROM table1 LEFT JOIN table2" + + +def test_migrator_response_full(): + """Test full MigratorResponse with all fields populated.""" + candidate = MigratorCandidate(plan="Use JOIN", candidate_query="SELECT * FROM A JOIN B") + response = MigratorResponse( + input_summary="Translates query logic", + candidates=[candidate], + input_error="Syntax error in original query", + summary="Final version corrected", + rai_error="RAI flag triggered" + ) + assert response.input_summary == "Translates query logic" + assert len(response.candidates) == 1 + assert response.input_error == "Syntax error in original query" + assert response.summary == "Final version corrected" + assert response.rai_error == "RAI flag triggered" + + +def test_migrator_response_defaults(): + """Test MigratorResponse with only required fields.""" + candidate = MigratorCandidate(plan="Use EXISTS", candidate_query="SELECT ...") + response = MigratorResponse(input_summary="Check optimization", candidates=[candidate]) + assert response.input_error is None + assert response.summary is None + assert response.rai_error is None diff --git a/src/tests/backend/sql_agents/agents/picker/picker_agent_test.py b/src/tests/backend/sql_agents/agents/picker/picker_agent_test.py new file mode 100644 index 0000000..d7069c1 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/picker/picker_agent_test.py @@ -0,0 +1,32 @@ +from unittest.mock import MagicMock + +import pytest + +from sql_agents.agents.picker.agent import PickerAgent +from sql_agents.agents.picker.response import PickerResponse +from sql_agents.helpers.models import AgentType + + +@pytest.fixture +def mock_config(): + return MagicMock(model_type={AgentType.PICKER: "picker-model-v1"}) + + +@pytest.fixture +def picker_agent(mock_config): + return PickerAgent(config=mock_config, agent_type=AgentType.PICKER) + + +def test_response_object(picker_agent): + """Test that the response_object property returns PickerResponse.""" + assert picker_agent.response_object is PickerResponse + + +def test_num_candidates(picker_agent): + """Test that the num_candidates property returns 3.""" + assert picker_agent.num_candidates == 3 + + +def test_deployment_name(picker_agent): + """Test that the deployment_name returns the correct model name.""" + assert picker_agent.deployment_name == "picker-model-v1" diff --git a/src/tests/backend/sql_agents/agents/picker/picker_response_test.py b/src/tests/backend/sql_agents/agents/picker/picker_response_test.py new file mode 100644 index 0000000..7036f47 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/picker/picker_response_test.py @@ -0,0 +1,26 @@ +from sql_agents.agents.picker.response import PickerResponse + + +def test_picker_response_all_fields(): + """Test PickerResponse with all fields provided.""" + response = PickerResponse( + conclusion="Chosen candidate is accurate.", + picked_query="SELECT * FROM users", + summary="Summary of the selection process." + ) + + assert response.conclusion == "Chosen candidate is accurate." + assert response.picked_query == "SELECT * FROM users" + assert response.summary == "Summary of the selection process." + + +def test_picker_response_optional_summary_none(): + """Test PickerResponse when optional summary is None.""" + response = PickerResponse( + conclusion="No valid candidates.", + picked_query="SELECT * FROM fallback", + summary=None + ) + + assert response.summary is None + assert response.picked_query == "SELECT * FROM fallback" diff --git a/src/tests/backend/sql_agents/agents/semantic_verifier/semantic_verifier_agent_test.py b/src/tests/backend/sql_agents/agents/semantic_verifier/semantic_verifier_agent_test.py new file mode 100644 index 0000000..56c1166 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/semantic_verifier/semantic_verifier_agent_test.py @@ -0,0 +1,48 @@ +from unittest.mock import MagicMock + +import pytest + +from sql_agents.agents.semantic_verifier.agent import SemanticVerifierAgent +from sql_agents.agents.semantic_verifier.response import SemanticVerifierResponse +from sql_agents.helpers.models import AgentType + + +@pytest.fixture +def mock_config(): + """Fixture to create a mock configuration.""" + mock_config = MagicMock() + mock_config.model_type = { + AgentType.SEMANTIC_VERIFIER: "semantic_verifier_model" + } + return mock_config + + +@pytest.fixture +def semantic_verifier_agent(mock_config): + """Fixture to create a SemanticVerifierAgent instance.""" + agent = SemanticVerifierAgent( + agent_type=AgentType.SEMANTIC_VERIFIER, + config=mock_config + ) + return agent + + +def test_response_object(semantic_verifier_agent): + """Test that the response_object property returns SemanticVerifierResponse.""" + assert semantic_verifier_agent.response_object == SemanticVerifierResponse + + +def test_deployment_name(semantic_verifier_agent): + """Test that the deployment_name property returns the correct model name.""" + assert semantic_verifier_agent.deployment_name == "semantic_verifier_model" + + +def test_missing_deployment_name(mock_config): + """Test that accessing deployment_name raises a KeyError if the model type is missing.""" + mock_config.model_type = {} + agent = SemanticVerifierAgent( + agent_type=AgentType.SEMANTIC_VERIFIER, + config=mock_config + ) + with pytest.raises(KeyError): + _ = agent.deployment_name diff --git a/src/tests/backend/sql_agents/agents/semantic_verifier/semantic_verifier_response_test.py b/src/tests/backend/sql_agents/agents/semantic_verifier/semantic_verifier_response_test.py new file mode 100644 index 0000000..4e2370d --- /dev/null +++ b/src/tests/backend/sql_agents/agents/semantic_verifier/semantic_verifier_response_test.py @@ -0,0 +1,62 @@ +import pytest + +from sql_agents.agents.semantic_verifier.response import SemanticVerifierResponse + + +def test_semantic_verifier_response_initialization(): + """Test initializing SemanticVerifierResponse with valid data.""" + response = SemanticVerifierResponse( + judgement="valid", + differences=["difference1", "difference2"], + summary="This is a summary." + ) + assert response.judgement == "valid" + assert response.differences == ["difference1", "difference2"] + assert response.summary == "This is a summary." + + +def test_semantic_verifier_response_empty_fields(): + """Test initializing SemanticVerifierResponse with empty fields.""" + response = SemanticVerifierResponse( + judgement="", + differences=[], + summary="" + ) + assert response.judgement == "" + assert response.differences == [] + assert response.summary == "" + + +def test_semantic_verifier_response_invalid_data(): + """Test initializing SemanticVerifierResponse with invalid data.""" + with pytest.raises(ValueError): + SemanticVerifierResponse( + judgement=123, # Invalid type + differences="not a list", # Invalid type + summary=None # Invalid type + ) + + +def test_semantic_verifier_response_large_differences(): + """Test initializing SemanticVerifierResponse with a large number of differences.""" + differences = [f"difference{i}" for i in range(1000)] # Large list of differences + response = SemanticVerifierResponse( + judgement="valid", + differences=differences, + summary="This is a summary." + ) + assert len(response.differences) == 1000 + assert response.judgement == "valid" + assert response.summary == "This is a summary." + + +def test_semantic_verifier_response_special_characters(): + """Test initializing SemanticVerifierResponse with special characters.""" + response = SemanticVerifierResponse( + judgement="valid!@#$%^&*()", + differences=["difference1", "difference2"], + summary="This is a summary with special characters!@#$%^&*()" + ) + assert response.judgement == "valid!@#$%^&*()" + assert response.differences == ["difference1", "difference2"] + assert response.summary == "This is a summary with special characters!@#$%^&*()" diff --git a/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_agent_test.py b/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_agent_test.py new file mode 100644 index 0000000..5494a12 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_agent_test.py @@ -0,0 +1,57 @@ +from unittest.mock import MagicMock + +import pytest + +from sql_agents.agents.syntax_checker.agent import SyntaxCheckerAgent +from sql_agents.agents.syntax_checker.plug_ins import SyntaxCheckerPlugin +from sql_agents.agents.syntax_checker.response import SyntaxCheckerResponse +from sql_agents.helpers.models import AgentType + + +@pytest.fixture +def mock_config(): + """Fixture to create a mock configuration.""" + mock_config = MagicMock() + mock_config.model_type = { + AgentType.SYNTAX_CHECKER: "syntax_checker_model" + } + return mock_config + + +@pytest.fixture +def syntax_checker_agent(mock_config): + """Fixture to create a SyntaxCheckerAgent instance.""" + agent = SyntaxCheckerAgent( + agent_type=AgentType.SYNTAX_CHECKER, + config=mock_config + ) + return agent + + +def test_response_object(syntax_checker_agent): + """Test that the response_object property returns SyntaxCheckerResponse.""" + assert syntax_checker_agent.response_object == SyntaxCheckerResponse + + +def test_plugins(syntax_checker_agent): + """Test that the plugins property returns the correct plugins.""" + plugins = syntax_checker_agent.plugins + assert isinstance(plugins, list) + assert plugins[0] == "check_syntax" + assert isinstance(plugins[1], SyntaxCheckerPlugin) + + +def test_deployment_name(syntax_checker_agent): + """Test that the deployment_name property returns the correct model name.""" + assert syntax_checker_agent.deployment_name == "syntax_checker_model" + + +def test_missing_deployment_name(mock_config): + """Test that accessing deployment_name raises a KeyError if the model type is missing.""" + mock_config.model_type = {} # Simulate missing AgentType in model_type + agent = SyntaxCheckerAgent( + agent_type=AgentType.SYNTAX_CHECKER, + config=mock_config + ) + with pytest.raises(KeyError): + _ = agent.deployment_name diff --git a/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_plug_ins_test.py b/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_plug_ins_test.py new file mode 100644 index 0000000..d076139 --- /dev/null +++ b/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_plug_ins_test.py @@ -0,0 +1,171 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from sql_agents.agents.syntax_checker.plug_ins import SyntaxCheckerPlugin + + +@pytest.fixture +def syntax_checker_plugin(): + """Fixture to create a SyntaxCheckerPlugin instance.""" + return SyntaxCheckerPlugin() + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_windows_path(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method on Windows.""" + with patch("platform.system", return_value="Windows"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "SELECT * FROM table" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + [r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_linux_path(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method on Linux.""" + with patch("platform.system", return_value="Linux"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "SELECT * FROM table" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + ["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_other_os(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method on other OS.""" + with patch("platform.system", return_value="Other"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "SELECT * FROM table" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + ["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_empty_string(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with an empty string.""" + with patch("platform.system", return_value="Windows"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + [r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_empty_string_linux(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with an empty string.""" + with patch("platform.system", return_value="Linux"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + ["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_empty_string_other_os(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with an empty string.""" + with patch("platform.system", return_value="Other"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + ["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_invalid_sql(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with invalid SQL.""" + with patch("platform.system", return_value="Windows"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "INVALID SQL" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + [r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_invalid_sql_linux(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with invalid SQL.""" + with patch("platform.system", return_value="Linux"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "INVALID SQL" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + ["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_invalid_sql_other_os(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with invalid SQL.""" + with patch("platform.system", return_value="Other"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "INVALID SQL" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + ["./sql_agents/tools/linux-x64/tsqlParser", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) + + +@patch("sql_agents.agents.syntax_checker.plug_ins.subprocess.run") +def test_check_syntax_valid_sql(mock_subprocess_run, syntax_checker_plugin): + """Test the _call_tsqlparser method with valid SQL.""" + with patch("platform.system", return_value="Windows"): + mock_subprocess_run.return_value = MagicMock(stdout="[]") + candidate_sql = "SELECT * FROM table" + result = syntax_checker_plugin.check_syntax(candidate_sql) + assert result == "[]" + mock_subprocess_run.assert_called_once_with( + [r".\sql_agents\tools\win-x64\tsqlParser.exe", "--string", candidate_sql], + capture_output=True, + text=True, + check=True, + ) diff --git a/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_response_test.py b/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_response_test.py new file mode 100644 index 0000000..f79769a --- /dev/null +++ b/src/tests/backend/sql_agents/agents/syntax_checker/syntax_checker_response_test.py @@ -0,0 +1,37 @@ +from sql_agents.agents.syntax_checker.response import SyntaxCheckerResponse, SyntaxErrorInt + + +def test_syntax_error_int_initialization(): + """Test initializing SyntaxErrorInt with valid data.""" + syntax_error = SyntaxErrorInt(line=1, column=5, error="Syntax error") + assert syntax_error.line == 1 + assert syntax_error.column == 5 + assert syntax_error.error == "Syntax error" + + +def test_syntax_checker_response_initialization(): + """Test initializing SyntaxCheckerResponse with valid data.""" + syntax_error = SyntaxErrorInt(line=1, column=5, error="Syntax error") + response = SyntaxCheckerResponse( + thought="Analyzing SQL query", + syntax_errors=[syntax_error], + summary="1 syntax error found" + ) + assert response.thought == "Analyzing SQL query" + assert len(response.syntax_errors) == 1 + assert response.syntax_errors[0].line == 1 + assert response.syntax_errors[0].column == 5 + assert response.syntax_errors[0].error == "Syntax error" + assert response.summary == "1 syntax error found" + + +def test_syntax_checker_response_empty_fields(): + """Test initializing SyntaxCheckerResponse with empty fields.""" + response = SyntaxCheckerResponse( + thought="", + syntax_errors=[], + summary="" + ) + assert response.thought == "" + assert response.syntax_errors == [] + assert response.summary == "" diff --git a/src/tests/backend/sql_agents/convert_script_test.py b/src/tests/backend/sql_agents/convert_script_test.py new file mode 100644 index 0000000..224f03e --- /dev/null +++ b/src/tests/backend/sql_agents/convert_script_test.py @@ -0,0 +1,101 @@ +import datetime +import json +from unittest.mock import AsyncMock, MagicMock, patch + +from common.models.api import FileRecord, ProcessStatus + +import pytest + +from semantic_kernel.agents import Agent +from semantic_kernel.contents import AuthorRole +from semantic_kernel.contents import ChatMessageContent + +from sql_agents.convert_script import validate_migration + + +class DummyAgent(Agent): + async def invoke(self, *args, **kwargs): + return "dummy response" + + async def invoke_stream(self, *args, **kwargs): + yield "dummy stream" + + async def get_response(self, *args, **kwargs): + return "dummy response" + + +@pytest.fixture +def file_record(): + return FileRecord( + batch_id="batch-123", + file_id="file-456", + original_name="test.sql", + blob_path="path/to/blob.sql", + translated_path="path/to/translated.sql", + status=ProcessStatus.READY_TO_PROCESS, + file_result=None, + syntax_count=0, + error_count=0, + created_at=datetime.datetime.utcnow(), + updated_at=datetime.datetime.utcnow(), + ) + + +@pytest.fixture +def mock_batch_service(): + service = MagicMock() + service.create_file_log = AsyncMock() + return service + + +@pytest.fixture +def mock_sql_agents(): + return MagicMock(idx_agents=["picker", "migrator", "syntax", "fixer", "verifier"]) + + +@pytest.fixture +def dummy_response_factory(): + def create_response(name, role, content): + return ChatMessageContent( + name=name, + role=role, + content=content + ) + return create_response + + +@pytest.mark.asyncio +@patch("sql_agents.convert_script.send_status_update") +async def test_validate_migration_success(mock_status, file_record, mock_batch_service): + chat_response = ChatMessageContent(name="picker", role="assistant", content="summary") + result = await validate_migration("SELECT * FROM valid;", chat_response, file_record, mock_batch_service) + assert result is True + assert mock_batch_service.create_file_log.await_count == 1 + + +@pytest.mark.asyncio +@patch("sql_agents.convert_script.send_status_update") +async def test_validate_migration_failure(mock_status, file_record, mock_batch_service): + result = await validate_migration("", None, file_record, mock_batch_service) + assert result is False + assert mock_batch_service.create_file_log.await_count == 1 + + +# Helper for async for loop +async def async_generator(responses): + for r in responses: + yield r + + +def dummy_response(name, content, role=AuthorRole.ASSISTANT.value): + response = MagicMock() + response.name = name + response.role = role + response.content = json.dumps(content) + return response + + +# Async generator utility +async def async_gen(responses): + for res in responses: + yield res diff --git a/src/tests/backend/sql_agents/helpers/agents_manager_test.py b/src/tests/backend/sql_agents/helpers/agents_manager_test.py new file mode 100644 index 0000000..489546c --- /dev/null +++ b/src/tests/backend/sql_agents/helpers/agents_manager_test.py @@ -0,0 +1,92 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from sql_agents.agents.agent_config import AgentBaseConfig +from sql_agents.helpers.agents_manager import SqlAgents +from sql_agents.helpers.models import AgentType + + +@pytest.mark.asyncio +async def test_create_sql_agents_success(): + config = MagicMock(spec=AgentBaseConfig) + + with patch("sql_agents.helpers.agents_manager.setup_fixer_agent", new_callable=AsyncMock) as mock_fixer, \ + patch("sql_agents.helpers.agents_manager.setup_migrator_agent", new_callable=AsyncMock) as mock_migrator, \ + patch("sql_agents.helpers.agents_manager.setup_picker_agent", new_callable=AsyncMock) as mock_picker, \ + patch("sql_agents.helpers.agents_manager.setup_syntax_checker_agent", new_callable=AsyncMock) as mock_syntax, \ + patch("sql_agents.helpers.agents_manager.setup_semantic_verifier_agent", new_callable=AsyncMock) as mock_semantic: + + # Setup mock return values + mock_fixer.return_value.id = "fixer-id" + mock_migrator.return_value.id = "migrator-id" + mock_picker.return_value.id = "picker-id" + mock_syntax.return_value.id = "syntax-id" + mock_semantic.return_value.id = "semantic-id" + + agents = await SqlAgents.create(config) + + assert agents.agent_config == config + assert agents.agent_fixer.id == "fixer-id" + assert agents.agent_migrator.id == "migrator-id" + assert agents.agent_picker.id == "picker-id" + assert agents.agent_syntax_checker.id == "syntax-id" + assert agents.agent_semantic_verifier.id == "semantic-id" + + assert len(agents.agents) == 5 + assert agents.idx_agents[AgentType.MIGRATOR].id == "migrator-id" + + +@pytest.mark.asyncio +async def test_create_sql_agents_failure(): + config = MagicMock(spec=AgentBaseConfig) + + with patch("sql_agents.helpers.agents_manager.setup_fixer_agent", new_callable=AsyncMock) as mock_fixer: + mock_fixer.side_effect = ValueError("Failed to create fixer") + + with pytest.raises(ValueError, match="Failed to create fixer"): + await SqlAgents.create(config) + + +@pytest.mark.asyncio +async def test_delete_agents_success(): + # Create a dummy agent with id + agent_mock = MagicMock() + agent_mock.id = "agent-id" + + config = MagicMock() + config.ai_project_client.agents.delete_agent = AsyncMock() + + agents = SqlAgents() + agents.agent_config = config + agents.agent_migrator = agent_mock + agents.agent_picker = agent_mock + agents.agent_syntax_checker = agent_mock + agents.agent_fixer = agent_mock + agents.agent_semantic_verifier = agent_mock + + await agents.delete_agents() + + assert config.ai_project_client.agents.delete_agent.await_count == 5 + config.ai_project_client.agents.delete_agent.assert_called_with("agent-id") + + +@pytest.mark.asyncio +async def test_delete_agents_with_exception(caplog): + agent_mock = MagicMock() + agent_mock.id = "agent-id" + + config = MagicMock() + config.ai_project_client.agents.delete_agent = AsyncMock(side_effect=Exception("delete failed")) + + agents = SqlAgents() + agents.agent_config = config + agents.agent_migrator = agent_mock + agents.agent_picker = agent_mock + agents.agent_syntax_checker = agent_mock + agents.agent_fixer = agent_mock + agents.agent_semantic_verifier = agent_mock + + await agents.delete_agents() + + assert "Error deleting agents: delete failed" in caplog.text diff --git a/src/tests/backend/sql_agents/helpers/comms_manager_test.py b/src/tests/backend/sql_agents/helpers/comms_manager_test.py new file mode 100644 index 0000000..c059865 --- /dev/null +++ b/src/tests/backend/sql_agents/helpers/comms_manager_test.py @@ -0,0 +1,76 @@ +from unittest.mock import MagicMock + +import pytest + +from semantic_kernel.agents import Agent + +from sql_agents.helpers.comms_manager import CommsManager +from sql_agents.helpers.models import AgentType + + +class DummyAgent: + def __init__(self, name): + self.name = name + + +class DummyHistory: + def __init__(self, name, content): + self.name = name + self.content = content + + +def mock_agent(name: str) -> Agent: + agent = MagicMock(spec=Agent) + agent.name = name + return agent + + +@pytest.fixture +def agents(): + return { + AgentType.MIGRATOR: mock_agent("migrator"), + AgentType.PICKER: mock_agent("picker"), + AgentType.SYNTAX_CHECKER: mock_agent("syntax_checker"), + AgentType.FIXER: mock_agent("fixer"), + AgentType.SEMANTIC_VERIFIER: mock_agent("semantic_verifier"), + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("last_agent, expected_next_agent", [ + (AgentType.MIGRATOR.value, AgentType.PICKER.value), + (AgentType.PICKER.value, AgentType.SYNTAX_CHECKER.value), + (AgentType.SYNTAX_CHECKER.value, AgentType.FIXER.value), + (AgentType.FIXER.value, AgentType.SYNTAX_CHECKER.value), + ("candidate", AgentType.SEMANTIC_VERIFIER.value), + ("unknown", AgentType.MIGRATOR.value), +]) +async def test_selection_strategy_select_agent(last_agent, expected_next_agent, agents): + strategy = CommsManager.SelectionStrategy(agents=agents.values()) + history = [DummyHistory(last_agent, "")] # dummy history item + result = await strategy.select_agent(list(agents.values()), history) + assert result.name == expected_next_agent + + +@pytest.mark.asyncio +async def test_should_agent_terminate_semantic_verifier(agents): + strategy = CommsManager.ApprovalTerminationStrategy( + agents=[agents[AgentType.MIGRATOR], agents[AgentType.SEMANTIC_VERIFIER]], + maximum_iterations=10, + automatic_reset=True, + ) + history = [DummyHistory(AgentType.SEMANTIC_VERIFIER.value, "content")] + terminate = await strategy.should_agent_terminate(agents[AgentType.SEMANTIC_VERIFIER], history) + assert terminate is True + + +@pytest.mark.asyncio +async def test_should_agent_terminate_other_agents(agents): + strategy = CommsManager.ApprovalTerminationStrategy( + agents=[agents[AgentType.MIGRATOR], agents[AgentType.SEMANTIC_VERIFIER]], + maximum_iterations=10, + automatic_reset=True, + ) + history = [DummyHistory(AgentType.SYNTAX_CHECKER.value, "content")] + terminate = await strategy.should_agent_terminate(agents[AgentType.SYNTAX_CHECKER], history) + assert terminate is False diff --git a/src/tests/backend/sql_agents/helpers/utils_test.py b/src/tests/backend/sql_agents/helpers/utils_test.py new file mode 100644 index 0000000..6955ec6 --- /dev/null +++ b/src/tests/backend/sql_agents/helpers/utils_test.py @@ -0,0 +1,45 @@ +from unittest import mock + +import pytest + +from sql_agents.helpers.utils import get_prompt, is_text + + +def test_get_prompt_valid_agent_type(): + agent_type = "agent1" + + # Mock the file reading + with mock.patch("builtins.open", mock.mock_open(read_data="This is the prompt")): + prompt = get_prompt(agent_type) + + # Assert the prompt returned is correct + assert prompt == "This is the prompt" + + +def test_get_prompt_invalid_agent_type(): + agent_type = "invalid-agent!" # Invalid agent type with a non-alphanumeric character + + # Expect a ValueError to be raised for an invalid agent type + with pytest.raises(ValueError): + get_prompt(agent_type) + + +def test_get_prompt_file_not_found(): + agent_type = "agent1" + + # Mock os.path.join and the file not being found + with mock.patch("builtins.open", mock.mock_open()) as mock_file: + mock_file.side_effect = FileNotFoundError + with pytest.raises(FileNotFoundError): + get_prompt(agent_type) + + +# Test for the is_text function +def test_is_text_empty_string(): + """Test when the content is an empty string.""" + assert not is_text("") + + +def test_is_text_non_empty_string(): + """Test when the content is a non-empty string.""" + assert is_text("Hello, world!") diff --git a/src/tests/backend/sql_agents/process_batch_test.py b/src/tests/backend/sql_agents/process_batch_test.py new file mode 100644 index 0000000..8f0e936 --- /dev/null +++ b/src/tests/backend/sql_agents/process_batch_test.py @@ -0,0 +1,420 @@ +import datetime +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import UUID + +from azure.core.exceptions import ServiceResponseError as ServiceResponseException + +from common.models.api import FileRecord, FileResult, LogType, ProcessStatus + +import pytest + +from semantic_kernel.contents import AuthorRole + +from sql_agents.helpers.models import AgentType +from sql_agents.process_batch import add_rai_disclaimer, process_batch_async, process_error + + +@pytest.mark.asyncio +@patch("sql_agents.process_batch.add_rai_disclaimer", return_value="SELECT * FROM converted;") +@patch("sql_agents.process_batch.process_error", new_callable=AsyncMock) +@patch("sql_agents.process_batch.is_text") +@patch("sql_agents.process_batch.send_status_update") +@patch("sql_agents.process_batch.convert_script", new_callable=AsyncMock) +@patch("sql_agents.process_batch.SqlAgents.create", new_callable=AsyncMock) +@patch("sql_agents.process_batch.AzureAIAgent.create_client") +@patch("sql_agents.process_batch.DefaultAzureCredential") +@patch("sql_agents.process_batch.BatchService") +@patch("sql_agents.process_batch.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_process_batch_async_success( + mock_get_storage, + mock_batch_service_cls, + mock_creds_cls, + mock_create_client, + mock_sql_agents_create, + mock_convert_script, + mock_send_status_update, + mock_is_text, + mock_process_error, + mock_add_disclaimer, +): + # UUID and timestamps for mocks + file_id = str(UUID(int=0)) + now = datetime.datetime.utcnow() + + # Mock file dict (simulates DB record) + file_dict = { + "file_id": file_id, + "blob_path": "blob/path", + "batch_id": "1", + "original_name": "file.sql", + "translated_path": "translated.sql", + "status": ProcessStatus.READY_TO_PROCESS.value, + "error_count": 0, + "syntax_count": 0, + "created_at": now, + "updated_at": now, + } + + # Setup BlobStorage mock + mock_storage = AsyncMock() + mock_storage.get_file.return_value = "SELECT * FROM dummy;" + mock_get_storage.return_value = mock_storage + + # Setup BatchService mock + mock_batch_service = AsyncMock() + mock_batch_service.database.get_batch_files.return_value = [file_dict] + mock_batch_service.initialize_database.return_value = None + mock_batch_service.update_batch.return_value = None + mock_batch_service.update_file_record.return_value = None + mock_batch_service.create_candidate.return_value = None + mock_batch_service.update_file_counts.return_value = None + mock_batch_service.batch_files_final_update.return_value = None + mock_batch_service.create_file_log.return_value = None + mock_batch_service_cls.return_value = mock_batch_service + + # Mock DefaultAzureCredential async context manager + mock_creds = MagicMock() + mock_creds.__aenter__.return_value = mock_creds + mock_creds.__aexit__.return_value = None + mock_creds_cls.return_value = mock_creds + + # Mock AzureAIAgent.create_client async context manager + mock_client = MagicMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_create_client.return_value = mock_client + + # Mock SqlAgents.create and its delete_agents + mock_agents = AsyncMock() + mock_agents.delete_agents.return_value = None + mock_sql_agents_create.return_value = mock_agents + + # Successful text file path + mock_is_text.return_value = True + mock_convert_script.return_value = "SELECT * FROM converted;" + + # Mock the FileRecord.fromdb method to return a mock file_record + mock_file_record = MagicMock() + mock_file_record.status = ProcessStatus.READY_TO_PROCESS + mock_file_record.file_id = file_id + mock_file_record.blob_path = file_dict["blob_path"] + mock_file_record.batch_id = file_dict["batch_id"] + mock_file_record.file_result = FileResult.SUCCESS + mock_file_record.error_count = 0 + + # Patch FileRecord to return the mock + with patch("sql_agents.process_batch.FileRecord.fromdb", return_value=mock_file_record): + # Execute function + await process_batch_async("1", "informix", "tsql") + + # Assertions + mock_batch_service.update_file_record.assert_called_with(mock_file_record) + # mock_send_status_update.assert_called_once() + mock_process_error.assert_not_called() + + +@pytest.mark.asyncio +@patch("sql_agents.process_batch.add_rai_disclaimer", return_value="SELECT * FROM converted;") +@patch("sql_agents.process_batch.process_error", new_callable=AsyncMock) +@patch("sql_agents.process_batch.is_text") +@patch("sql_agents.process_batch.send_status_update") +@patch("sql_agents.process_batch.convert_script", new_callable=AsyncMock) +@patch("sql_agents.process_batch.SqlAgents.create", new_callable=AsyncMock) +@patch("sql_agents.process_batch.AzureAIAgent.create_client") +@patch("sql_agents.process_batch.DefaultAzureCredential") +@patch("sql_agents.process_batch.BatchService") +@patch("sql_agents.process_batch.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_process_batch_async_invalid_text_file( + mock_get_storage, + mock_batch_service_cls, + mock_creds_cls, + mock_create_client, + mock_sql_agents_create, + mock_convert_script, + mock_send_status_update, + mock_is_text, + mock_process_error, + mock_add_disclaimer, +): + # UUID and timestamps for mocks + file_id = str(UUID(int=1)) + now = datetime.datetime.utcnow() + + file_dict = { + "file_id": file_id, + "blob_path": "invalid/blob/path.sql", + "batch_id": "1", + "original_name": "bad.sql", + "translated_path": "translated.sql", + "status": ProcessStatus.READY_TO_PROCESS.value, + "error_count": 0, + "syntax_count": 0, + "created_at": now, + "updated_at": now, + } + + # Mock the BlobStorage get_file to return dummy content + mock_storage = AsyncMock() + mock_storage.get_file.return_value = "binary content" + mock_get_storage.return_value = mock_storage + + # Setup BatchService mock + mock_batch_service = AsyncMock() + mock_batch_service.database.get_batch_files.return_value = [file_dict] + mock_batch_service.initialize_database.return_value = None + mock_batch_service.update_batch.return_value = None + mock_batch_service.update_file_record.return_value = None + mock_batch_service.create_file_log.return_value = None + mock_batch_service_cls.return_value = mock_batch_service + + # Setup Azure credential and client mocks + mock_creds = MagicMock() + mock_creds.__aenter__.return_value = mock_creds + mock_creds.__aexit__.return_value = None + mock_creds_cls.return_value = mock_creds + + mock_client = MagicMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_create_client.return_value = mock_client + + mock_sql_agents = AsyncMock() + mock_sql_agents.delete_agents.return_value = None + mock_sql_agents_create.return_value = mock_sql_agents + + # Simulate non-text file + mock_is_text.return_value = False + + # Patch FileRecord.fromdb to return a mock + mock_file_record = MagicMock() + mock_file_record.file_id = file_id + mock_file_record.batch_id = file_dict["batch_id"] + mock_file_record.blob_path = file_dict["blob_path"] + mock_file_record.status = ProcessStatus.READY_TO_PROCESS + mock_file_record.file_result = FileResult.SUCCESS + mock_file_record.error_count = 0 + + with patch("sql_agents.process_batch.FileRecord.fromdb", return_value=mock_file_record): + await process_batch_async("1", "informix", "tsql") + + # Assertions: ensure correct flow for invalid text file + mock_is_text.assert_called_once() + mock_batch_service.create_file_log.assert_called_once_with( + str(file_id), + "File is not a valid text file. Skipping.", + "", + LogType.ERROR, + AgentType.ALL, + AuthorRole.ASSISTANT, + ) + mock_send_status_update.assert_called_once() + mock_batch_service.update_file_record.assert_called_with(mock_file_record) + assert mock_file_record.status == ProcessStatus.COMPLETED + assert mock_file_record.file_result == FileResult.ERROR + assert mock_file_record.error_count == 1 + + +@pytest.mark.asyncio +@patch("sql_agents.process_batch.add_rai_disclaimer", return_value="SELECT * FROM converted;") +@patch("sql_agents.process_batch.process_error", new_callable=AsyncMock) +@patch("sql_agents.process_batch.is_text") +@patch("sql_agents.process_batch.send_status_update") +@patch("sql_agents.process_batch.convert_script", new_callable=AsyncMock) +@patch("sql_agents.process_batch.SqlAgents.create", new_callable=AsyncMock) +@patch("sql_agents.process_batch.AzureAIAgent.create_client") +@patch("sql_agents.process_batch.DefaultAzureCredential") +@patch("sql_agents.process_batch.BatchService") +@patch("sql_agents.process_batch.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_process_batch_unicode_decode_error( + mock_get_storage, + mock_batch_service_cls, + mock_creds_cls, + mock_create_client, + mock_sql_agents_create, + mock_convert_script, + mock_send_status_update, + mock_is_text, + mock_process_error, + mock_add_disclaimer, +): + file_id = str(UUID(int=2)) + now = datetime.datetime.utcnow() + file_dict = { + "file_id": file_id, + "blob_path": "bad/path.sql", + "batch_id": "2", + "original_name": "bad.sql", + "translated_path": "converted.sql", + "status": ProcessStatus.READY_TO_PROCESS.value, + "error_count": 0, + "syntax_count": 0, + "created_at": now, + "updated_at": now, + } + + # Trigger UnicodeDecodeError on get_file + mock_storage = AsyncMock() + mock_storage.get_file.side_effect = UnicodeDecodeError("utf-8", b"", 0, 1, "error") + mock_get_storage.return_value = mock_storage + + mock_batch_service = AsyncMock() + mock_batch_service.database.get_batch_files.return_value = [file_dict] + mock_batch_service_cls.return_value = mock_batch_service + + mock_creds_cls.return_value.__aenter__.return_value = MagicMock() + mock_create_client.return_value.__aenter__.return_value = MagicMock() + mock_sql_agents_create.return_value = AsyncMock() + + mock_file_record = MagicMock() + mock_file_record.file_id = file_id + mock_file_record.batch_id = "2" + mock_file_record.blob_path = "bad/path.sql" + with patch("sql_agents.process_batch.FileRecord.fromdb", return_value=mock_file_record): + await process_batch_async("2", "informix", "tsql") + + mock_process_error.assert_called_once() + + +@pytest.mark.asyncio +@patch("sql_agents.process_batch.process_error", new_callable=AsyncMock) +@patch("sql_agents.process_batch.is_text") +@patch("sql_agents.process_batch.SqlAgents.create", new_callable=AsyncMock) +@patch("sql_agents.process_batch.AzureAIAgent.create_client") +@patch("sql_agents.process_batch.DefaultAzureCredential") +@patch("sql_agents.process_batch.BatchService") +@patch("sql_agents.process_batch.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_process_batch_service_response_exception( + mock_get_storage, + mock_batch_service_cls, + mock_creds_cls, + mock_create_client, + mock_sql_agents_create, + mock_is_text, + mock_process_error, +): + file_id = str(UUID(int=3)) + now = datetime.datetime.utcnow() + file_dict = { + "file_id": file_id, + "blob_path": "some/blob.sql", + "batch_id": "3", + "original_name": "serviceerror.sql", + "translated_path": "translated.sql", + "status": ProcessStatus.READY_TO_PROCESS.value, + "error_count": 0, + "syntax_count": 0, + "created_at": now, + "updated_at": now, + } + + # Trigger ServiceResponseException + mock_storage = AsyncMock() + mock_storage.get_file.side_effect = ServiceResponseException("Service error") + mock_get_storage.return_value = mock_storage + + mock_batch_service = AsyncMock() + mock_batch_service.database.get_batch_files.return_value = [file_dict] + mock_batch_service_cls.return_value = mock_batch_service + + mock_creds_cls.return_value.__aenter__.return_value = MagicMock() + mock_create_client.return_value.__aenter__.return_value = MagicMock() + mock_sql_agents_create.return_value = AsyncMock() + + mock_file_record = MagicMock() + mock_file_record.file_id = file_id + mock_file_record.batch_id = "3" + mock_file_record.blob_path = "some/blob.sql" + + with patch("sql_agents.process_batch.FileRecord.fromdb", return_value=mock_file_record): + await process_batch_async("3", "informix", "tsql") + + mock_process_error.assert_called_once() + + +@pytest.mark.asyncio +@patch("sql_agents.process_batch.process_error", new_callable=AsyncMock) +@patch("sql_agents.process_batch.is_text") +@patch("sql_agents.process_batch.SqlAgents.create", new_callable=AsyncMock) +@patch("sql_agents.process_batch.AzureAIAgent.create_client") +@patch("sql_agents.process_batch.DefaultAzureCredential") +@patch("sql_agents.process_batch.BatchService") +@patch("sql_agents.process_batch.BlobStorageFactory.get_storage", new_callable=AsyncMock) +async def test_process_batch_generic_exception( + mock_get_storage, + mock_batch_service_cls, + mock_creds_cls, + mock_create_client, + mock_sql_agents_create, + mock_is_text, + mock_process_error, +): + file_id = str(UUID(int=4)) + now = datetime.datetime.utcnow() + file_dict = { + "file_id": file_id, + "blob_path": "generic/blob.sql", + "batch_id": "4", + "original_name": "fail.sql", + "translated_path": "translated.sql", + "status": ProcessStatus.READY_TO_PROCESS.value, + "error_count": 0, + "syntax_count": 0, + "created_at": now, + "updated_at": now, + } + + # Trigger generic exception + mock_storage = AsyncMock() + mock_storage.get_file.side_effect = Exception("Unexpected failure") + mock_get_storage.return_value = mock_storage + + mock_batch_service = AsyncMock() + mock_batch_service.database.get_batch_files.return_value = [file_dict] + mock_batch_service_cls.return_value = mock_batch_service + + mock_creds_cls.return_value.__aenter__.return_value = MagicMock() + mock_create_client.return_value.__aenter__.return_value = MagicMock() + mock_sql_agents_create.return_value = AsyncMock() + + mock_file_record = MagicMock() + mock_file_record.file_id = file_id + mock_file_record.batch_id = "4" + mock_file_record.blob_path = "generic/blob.sql" + + with patch("sql_agents.process_batch.FileRecord.fromdb", return_value=mock_file_record): + await process_batch_async("4", "informix", "tsql") + + mock_process_error.assert_called_once() + + +@pytest.mark.asyncio +@patch("sql_agents.process_batch.send_status_update") +async def test_process_error(mock_send_status_update): + # Setup complete FileRecord with required fields + file_record = FileRecord( + file_id="1", + batch_id="b1", + blob_path="blobpath", + original_name="file.sql", + translated_path="translated.sql", + status=ProcessStatus.READY_TO_PROCESS, + error_count=0, + syntax_count=0, + created_at=datetime.datetime.utcnow(), + updated_at=datetime.datetime.utcnow(), + ) + + batch_service = AsyncMock() + + await process_error(ValueError("Test error"), file_record, batch_service) + + batch_service.create_file_log.assert_awaited_once() + mock_send_status_update.assert_called_once() + assert mock_send_status_update.call_args[1]["status"].file_result == FileResult.ERROR + + +def test_add_rai_disclaimer(): + original = "SELECT * FROM test;" + result = add_rai_disclaimer(original) + assert result.startswith("/*") + assert original in result