diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml index bd281b5c..c345ea1a 100644 --- a/.github/release-drafter.yml +++ b/.github/release-drafter.yml @@ -7,8 +7,11 @@ version-template: "2023.10.06.$PATCH" template: | #### What's Changed $CHANGES + footer: | + Built to help reduce copy/paste from multiple projects and uses calendar versioning (year.month.day.build) from [BumpCalver](https://github.com/devsetgo/bumpcalver). + categories: - title: 'Breaking' label: 'type: breaking' diff --git a/coverage-badge.svg b/coverage-badge.svg index 073992bf..a00822ea 100644 --- a/coverage-badge.svg +++ b/coverage-badge.svg @@ -1 +1 @@ -coverage: 99.16%coverage99.16% +coverage: 100.00%coverage100.00% diff --git a/coverage.xml b/coverage.xml index 651f64ce..9937d12b 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,6 +1,6 @@ - - + + /github/workspace @@ -25,7 +25,7 @@ - + @@ -132,7 +132,7 @@ - + @@ -214,155 +214,185 @@ + + + + + + + + + + + + + + + + + + + + + + - + - + + + - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - + + + + + + + + + + - - - - + + + + - - - - - - - - - - - - + + + + + + + + + + + + - - + + - - - - - - - - - - - - - + + + + + + + + + + + - - - - + + - + + + + - - - - - - - - - + + + + + + + - - + + - + + - - - - - - - - - + - + @@ -446,7 +476,7 @@ - + @@ -482,7 +512,7 @@ - + diff --git a/dsg_lib/__init__.py b/dsg_lib/__init__.py index 8a738f50..2e634f7d 100644 --- a/dsg_lib/__init__.py +++ b/dsg_lib/__init__.py @@ -8,7 +8,7 @@ """ from datetime import date -__version__ = "2025.04.05-001" +__version__ = "2025.04.17-001" __author__ = "Mike Ryan" __license__ = "MIT" __copyright__ = f"Copyright© 2021-{date.today().year}" diff --git a/dsg_lib/async_database_functions/database_operations.py b/dsg_lib/async_database_functions/database_operations.py index 29bc8fd1..34dbe143 100644 --- a/dsg_lib/async_database_functions/database_operations.py +++ b/dsg_lib/async_database_functions/database_operations.py @@ -623,48 +623,6 @@ async def read_one_record(self, query): return handle_exceptions(ex) # pragma: no cover async def read_query(self, query): - """ - Executes a fetch query on the database and returns a list of records - that match the query. - - This asynchronous method accepts a SQLAlchemy `Select` query object. - It returns a list of records that match the query. - - Parameters: - query (Select): A SQLAlchemy `Select` query object specifying the - conditions to fetch records for. - - Returns: - list: A list of records that match the query. - - Raises: - Exception: If any error occurs during the execution of the query. - - Example: - ```python - from dsg_lib.async_database_functions import ( - async_database, - base_schema, - database_config, - database_operations, - ) - # Create a DBConfig instance - config = { - "database_uri": "sqlite+aiosqlite:///:memory:?cache=shared", - "echo": False, - "future": True, - "pool_recycle": 3600, - } - # create database configuration - db_config = database_config.DBConfig(config) - # Create an AsyncDatabase instance - async_db = async_database.AsyncDatabase(db_config) - # Create a DatabaseOperations instance - db_ops = database_operations.DatabaseOperations(async_db) - # read query - records = await db_ops.read_query(select(User).where(User.age > 30)) - ``` - """ # Log the start of the operation logger.debug("Starting read_query operation") @@ -676,82 +634,83 @@ async def read_query(self, query): # Execute the fetch query and retrieve the records result = await session.execute(query) - records = result.scalars().all() - logger.debug(f"read_query result: {records}") - # Log the successful query execution - if all( - isinstance(record, tuple) for record in records - ): # pragma: no cover - logger.debug(f"read_query result is a tuple {type(records)}") - # If all records are tuples, convert them to dictionaries - records_data = [ - dict(zip(("request_group_id", "count"), record, strict=False)) - for record in records - ] - else: - logger.debug(f"read_query result is a dictionary {type(records)}") - # Otherwise, try to convert the records to dictionaries using the __dict__ attribute - records_data = [record.__dict__ for record in records] - - logger.debug( - f"Fetch query executed successfully. Records: {records_data}" - ) + # Use result.keys() to determine number of columns in result + if hasattr(result, "keys") and callable(result.keys): + keys = result.keys() + if len(keys) == 1: + # Use scalars() for single-column queries + records = result.scalars().all() + else: + rows = result.fetchall() + records = [] + for row in rows: + if hasattr(row, "_mapping"): + mapping = row._mapping + if len(mapping) == 1:# pragma: no cover + records.append(list(mapping.values())[0])# pragma: no cover + else: + records.append(dict(mapping)) + elif hasattr(row, "__dict__"):# pragma: no cover + records.append(row)# pragma: no cover + else:# pragma: no cover + records.append(row)# pragma: no cover + else: + # Fallback to previous logic if keys() is not available + rows = result.fetchall() + records = [] + for row in rows: + if hasattr(row, "_mapping"): + mapping = row._mapping + if len(mapping) == 1: + records.append(list(mapping.values())[0]) + else:# pragma: no cover + records.append(dict(mapping))# pragma: no cover + elif hasattr(row, "__dict__"): + records.append(row) + else: + records.append(row)# pragma: no cover + logger.debug(f"read_query result: {records}") return records except Exception as ex: - # Handle any exceptions that occur during the query execution logger.error(f"Exception occurred: {ex}") return handle_exceptions(ex) async def read_multi_query(self, queries: Dict[str, str]): """ - Executes multiple fetch queries on the database and returns a dictionary - of results for each query. + Executes multiple fetch queries asynchronously and returns a dictionary of results for each query. - This asynchronous method takes a dictionary where each key is a query - name and each value is a SQLAlchemy `Select` query object. The method executes each - query and returns a dictionary where each key is the query name, and the - corresponding value is a list of records that match that query. + This asynchronous method accepts a dictionary where each key is a query name (str) + and each value is a SQLAlchemy `Select` query object. It executes each query within a single + database session and collects the results. The results are returned as a dictionary mapping + each query name to a list of records that match that query. - Parameters: - queries (Dict[str, Select]): A dictionary of SQLAlchemy `Select` - query objects. + The function automatically determines the structure of each result set: + - If the query returns a single column, the result will be a list of scalar values. + - If the query returns multiple columns, the result will be a list of dictionaries mapping column names to values. + - If the result row is an ORM object, it will be returned as-is. + + Args: + queries (Dict[str, Select]): A dictionary mapping query names to SQLAlchemy `Select` query objects. Returns: - dict: A dictionary where each key is a query name and each value is - a list of records that match the query. + Dict[str, List[Any]]: A dictionary where each key is a query name and each value is a list of records + (scalars, dictionaries, or ORM objects) that match the corresponding query. Raises: - Exception: If any error occurs during the execution of the queries. + Exception: If any error occurs during the execution of any query, the function logs the error and + returns a dictionary with error details using `handle_exceptions`. Example: ```python - from dsg_lib.async_database_functions import ( - async_database, - base_schema, - database_config, - database_operations, - ) - # Create a DBConfig instance - config = { - "database_uri": "sqlite+aiosqlite:///:memory:?cache=shared", - "echo": False, - "future": True, - "pool_recycle": 3600, - } - # create database configuration - db_config = database_config.DBConfig(config) - # Create an AsyncDatabase instance - async_db = async_database.AsyncDatabase(db_config) - # Create a DatabaseOperations instance - db_ops = database_operations.DatabaseOperations(async_db) - # read multi query + from sqlalchemy import select queries = { - "query1": select(User).where(User.age > 30), - "query2": select(User).where(User.age < 20), + "adults": select(User).where(User.age >= 18), + "minors": select(User).where(User.age < 18), } results = await db_ops.read_multi_query(queries) + # results["adults"] and results["minors"] will contain lists of records ``` """ # Log the start of the operation @@ -759,26 +718,46 @@ async def read_multi_query(self, queries: Dict[str, str]): try: results = {} - # Start a new database session async with self.async_db.get_db_session() as session: for query_name, query in queries.items(): - # Log the query being executed logger.debug(f"Executing fetch query: {query}") - - # Execute the fetch query and retrieve the records result = await session.execute(query) - data = result.scalars().all() - - # Convert the records to dictionaries for logging - data_dicts = [record.__dict__ for record in data] - logger.debug(f"Fetch result for query '{query_name}': {data_dicts}") - - # Store the records in the results dictionary + if hasattr(result, "keys") and callable(result.keys): + keys = result.keys() + if len(keys) == 1: + data = result.scalars().all() + else: + rows = result.fetchall() + data = [] + for row in rows: + if hasattr(row, "_mapping"): + mapping = row._mapping + if len(mapping) == 1: # pragma: no cover + data.append(list(mapping.values())[0]) # pragma: no cover + else: + data.append(dict(mapping)) + elif hasattr(row, "__dict__"): # pragma: no cover + data.append(row) # pragma: no cover + else:# pragma: no cover + data.append(row)# pragma: no cover + else: + rows = result.fetchall() + data = [] + for row in rows: + if hasattr(row, "_mapping"): + mapping = row._mapping + if len(mapping) == 1:# pragma: no cover + data.append(list(mapping.values())[0])# pragma: no cover + else: + data.append(dict(mapping)) + elif hasattr(row, "__dict__"): + data.append(row) + else: + data.append(row)# pragma: no cover results[query_name] = data return results except Exception as ex: - # Handle any exceptions that occur during the query execution logger.error(f"Exception occurred: {ex}") return handle_exceptions(ex) diff --git a/dsg_lib/fastapi_functions/system_health_endpoints.py b/dsg_lib/fastapi_functions/system_health_endpoints.py index f93b78df..5c52a2a5 100644 --- a/dsg_lib/fastapi_functions/system_health_endpoints.py +++ b/dsg_lib/fastapi_functions/system_health_endpoints.py @@ -148,7 +148,7 @@ def create_health_router(config: dict): from fastapi.responses import ORJSONResponse except ImportError: # pragma: no cover APIRouter = HTTPException = status = ORJSONResponse = fastapi = ( - None # pragma: no cover + None ) # Check FastAPI version diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index 1ed9cdf6..cab7939c 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -449,6 +449,36 @@ async def read_list_of_records( return records_list +@app.get("/database/get-list-of-distinct-records", tags=["Database Examples"]) +async def read_list_of_distinct_records(): + + # create many similar records to test distinct + queries = [] + for i in tqdm(range(100), desc="executing many fake users"): + value = f"Agent {i}" + queries.append( + ( + insert(User), + { + "first_name": value, + "last_name": "Smith", + "email": f"{value.lower()}@abc.com", + }, + ) + ) + + results = await db_ops.execute_many(queries) + print(results) + + distinct_last_name_query = Select(User.last_name).distinct() + logger.info(f"Executing query: {distinct_last_name_query}") + records = await db_ops.read_query(query=distinct_last_name_query) + + + logger.info(f"Read list of distinct records: {records}") + return records + + @app.post("/database/execute-one", tags=["Database Examples"]) async def execute_query(query: str = Body(...)): # add a user with execute_one @@ -481,6 +511,7 @@ async def execute_many(query: str = Body(...)): return query_return + if __name__ == "__main__": import uvicorn diff --git a/makefile b/makefile index 8a0cd4b9..f8469681 100644 --- a/makefile +++ b/makefile @@ -1,6 +1,6 @@ # Variables REPONAME = devsetgo_lib -APP_VERSION = 2025.04.05-001 +APP_VERSION = 2025.04.17-001 PYTHON = python3 PIP = $(PYTHON) -m pip PYTEST = $(PYTHON) -m pytest diff --git a/pyproject.toml b/pyproject.toml index 04b31508..544fb97f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "devsetgo_lib" -version = "2025.4.5.1" +version = "2025.4.17.1" requires-python = ">=3.9" description = "DevSetGo Library is a Python library offering reusable functions for efficient coding. It includes file operations, calendar utilities, pattern matching, advanced logging with loguru, FastAPI endpoints, async database handling, and email validation. Designed for ease of use and versatility, it's a valuable tool for Python developers.\n" keywords = [ "python", "library", "reusable functions", "file operations", "calendar utilities", "pattern matching", "logging", "loguru", "FastAPI", "async database", "CRUD operations", "email validation", "development tools",] diff --git a/report.xml b/report.xml index 35252c61..5d3d954b 100644 --- a/report.xml +++ b/report.xml @@ -1 +1 @@ - + diff --git a/tests-badge.svg b/tests-badge.svg index 90005645..45addb25 100644 --- a/tests-badge.svg +++ b/tests-badge.svg @@ -1 +1 @@ -tests: 134tests134 +tests: 149tests149 diff --git a/tests/test_database_functions/test_async_database.py b/tests/test_database_functions/test_async_database.py index 99c143d5..cf402132 100644 --- a/tests/test_database_functions/test_async_database.py +++ b/tests/test_database_functions/test_async_database.py @@ -26,6 +26,7 @@ class User(async_db.Base): __tablename__ = "users" pkid = Column(Integer, primary_key=True) name = Column(String, unique=True) + color = Column(String) # New column, not unique @pytest_asyncio.fixture(scope="class", autouse=True) @@ -104,6 +105,60 @@ async def test_read_query(self, db_ops): assert isinstance(data, list) assert len(data) > 0 + @pytest.mark.asyncio + async def test_read_query_distinct(self, db_ops): + # Insert users with unique names but duplicate colors + queries = [ + (insert(User), {'name': 'Alice', 'color': 'red'}), + (insert(User), {'name': 'Bob', 'color': 'blue'}), + (insert(User), {'name': 'Charlie', 'color': 'red'}), # Duplicate color + ] + await db_ops.execute_many(queries) + + # Debug: verify users are inserted + inserted = await db_ops.read_query(select(User)) + print("Inserted users:", [(u.name, u.color) for u in inserted]) + + # Distinct on color column + query = select(User.color).distinct() + result = await db_ops.read_query(query) + print("Distinct colors:", result) + # If result is a list of tuples, flatten it + if result and isinstance(result[0], tuple): + result = [r[0] for r in result] + assert set(result) == {"red", "blue"} + # Distinct on full row (should return unique rows) + query = select(User).distinct() + result = await db_ops.read_query(query) + assert all(isinstance(u, User) for u in result) + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_read_query_tuple_result(self, db_ops): + # Insert users with different colors + queries = [ + (insert(User), {'name': 'Alice', 'color': 'red'}), + (insert(User), {'name': 'Bob', 'color': 'blue'}), + (insert(User), {'name': 'Charlie', 'color': 'red'}), + ] + await db_ops.execute_many(queries) + # Query for color and name as a tuple + query = select(User.color, User.name).distinct() + result = await db_ops.read_query(query) + # Should be a list of dicts or tuples, depending on implementation + # If dicts, convert to tuples for comparison + processed = [] + for r in result: + if isinstance(r, dict): + processed.append((r.get("color"), r.get("name"))) + elif isinstance(r, tuple): + processed.append(r) + else: + # fallback for unexpected types + processed.append(tuple(r)) + expected = {('red', 'Alice'), ('blue', 'Bob'), ('red', 'Charlie')} + assert set(processed) == expected + @pytest.mark.asyncio async def test_read_query_sqlalchemy_error(self, db_ops, mocker): # Mock the get_db_session method to raise an SQLAlchemyError @@ -130,6 +185,216 @@ async def test_read_query_general_exception(self, db_ops, mocker): result = await db_ops.read_query(select(User)) assert result == {"error": "General Exception", "details": "Test error message"} + @pytest.mark.asyncio + async def test_read_query_mapping_single_and_multi(self, db_ops, mocker): + # Insert a user for single-column mapping + await db_ops.execute_one(insert(User).values(name="SingleMapping", color="green")) + # Patch the result to simulate _mapping with single and multi keys + class FakeRow: + def __init__(self, mapping): + self._mapping = mapping + + # Simulate single-column mapping + fake_rows = [FakeRow({"color": "green"})] + fake_result = mocker.Mock() + fake_result.keys = lambda: ["color"] + fake_result.fetchall.return_value = fake_rows + # Patch scalars().all() to return the expected list + fake_result.scalars.return_value.all.return_value = ["green"] + + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): return fake_result + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + + # Should hit the mapping == 1 branch (scalars().all()) + result = await db_ops.read_query(select(User.color)) + assert result == ["green"] + + # Simulate multi-column mapping + fake_rows = [FakeRow({"color": "green", "name": "SingleMapping"})] + fake_result.fetchall.return_value = fake_rows + fake_result.keys = lambda: ["color", "name"] + # Patch scalars().all() to not be used for multi-column + fake_result.scalars.return_value.all.return_value = None + + result = await db_ops.read_query(select(User.color, User.name)) + assert result == [{"color": "green", "name": "SingleMapping"}] + + @pytest.mark.asyncio + async def test_read_query_mapping_fallback_branches(self, db_ops, mocker): + # Patch the result to simulate fallback logic (no keys method) + class FakeRow: + def __init__(self, mapping): + self._mapping = mapping + + fake_rows = [FakeRow({"color": "blue"})] + fake_result = mocker.Mock() + del fake_result.keys # Remove keys method + fake_result.fetchall.return_value = fake_rows + + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): return fake_result + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + + result = await db_ops.read_query("fake_query") + assert result == ["blue"] + + # Fallback branch: row without _mapping but with __dict__ + class RowWithDict: + def __init__(self): + self.__dict__ = {"foo": "bar"} + fake_rows = [RowWithDict()] + fake_result.fetchall.return_value = fake_rows + + result = await db_ops.read_query("fake_query") + assert result == fake_rows + + # Fallback branch: row without _mapping or __dict__ + class RowPlain: + pass + fake_rows = [RowPlain()] + fake_result.fetchall.return_value = fake_rows + + result = await db_ops.read_query("fake_query") + assert result == fake_rows + + @pytest.mark.asyncio + async def test_read_query_no_keys_no_mapping(self,db_ops, mocker): + # Simulate result object without keys() and rows without _mapping or __dict__ + class FakeRow: + pass + + class FakeResult: + def fetchall(self): + return [FakeRow(), FakeRow()] + + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): return FakeResult() + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + result = await db_ops.read_query("fake_query") + # Should return the list of FakeRow objects + assert all(isinstance(r, FakeRow) for r in result) + + @pytest.mark.asyncio + async def test_read_multi_query_mapping_branches(self, db_ops, mocker): + # Patch the result to simulate mapping logic in read_multi_query + class FakeRow: + def __init__(self, mapping): + self._mapping = mapping + + # Single-column mapping + fake_rows = [FakeRow({"color": "red"})] + fake_result = mocker.Mock() + fake_result.keys = lambda: ["color"] + fake_result.fetchall.return_value = fake_rows + fake_result.scalars.return_value.all.return_value = ["red"] + + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): return fake_result + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + + queries = {"single": select(User.color)} + result = await db_ops.read_multi_query(queries) + assert result == {"single": ["red"]} + + # Multi-column mapping + fake_rows = [FakeRow({"color": "red", "name": "Alice"})] + fake_result.fetchall.return_value = fake_rows + fake_result.keys = lambda: ["color", "name"] + # Patch scalars().all() to not be used for multi-column + fake_result.scalars.return_value.all.return_value = None + + queries = {"multi": select(User.color, User.name)} + result = await db_ops.read_multi_query(queries) + assert result == {"multi": [{"color": "red", "name": "Alice"}]} + + # Fallback logic (no keys method) + del fake_result.keys + fake_rows = [FakeRow({"color": "blue"})] + fake_result.fetchall.return_value = fake_rows + + queries = {"fallback": select(User.color)} + result = await db_ops.read_multi_query(queries) + assert result == {"fallback": ["blue"]} + + fake_rows = [FakeRow({"color": "blue", "name": "Fallback"})] + fake_result.fetchall.return_value = fake_rows + + queries = {"fallback_multi": select(User.color, User.name)} + result = await db_ops.read_multi_query(queries) + assert result == {"fallback_multi": [{"color": "blue", "name": "Fallback"}]} + + @pytest.mark.asyncio + async def test_read_multi_query_no_keys_no_mapping(self, db_ops, mocker): + # Simulate result object without keys() and rows without _mapping or __dict__ + class FakeRow: + pass + + class FakeResult: + def fetchall(self): + return [FakeRow(), FakeRow()] + + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): return FakeResult() + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + queries = {"q1": "query1", "q2": "query2"} + result = await db_ops.read_multi_query(queries) + for v in result.values(): + assert all(isinstance(r, FakeRow) for r in v) + + @pytest.mark.asyncio + async def test_read_multi_query_fallback_branches(self, db_ops, mocker): + # Patch session and result to simulate fallback logic + class FakeRow: + def __init__(self, mapping): + self._mapping = mapping + + class RowWithDict: + def __init__(self): + self.__dict__ = {"foo": "bar"} + + class RowPlain: + pass + + fake_result = mocker.Mock() + # No keys method + del fake_result.keys + # Simulate three types of rows + fake_result.fetchall.side_effect = [ + [FakeRow({"color": "blue"})], # _mapping, single key + [RowWithDict()], # __dict__ + [RowPlain()] # plain + ] + + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): + return fake_result + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + + queries = {"q1": "query1", "q2": "query2", "q3": "query3"} + result = await db_ops.read_multi_query(queries) + assert result["q1"] == ["blue"] + assert isinstance(result["q2"][0], RowWithDict) + assert isinstance(result["q3"][0], RowPlain) + @pytest.mark.asyncio async def test_read_multi_query(self, db_ops): # db_ops is already awaited by pytest, so you can use it directly @@ -421,6 +686,19 @@ async def test_read_one_record_none(self, db_ops): # assert data is none assert data is None + @pytest.mark.asyncio + async def test_read_one_record_exception(self, db_ops, mocker): + # Patch session to raise a general exception + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query): raise Exception("fail!") + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + result = await db_ops.read_one_record("fake_query") + assert isinstance(result, dict) + assert "error" in result + @pytest.mark.asyncio async def test_delete_many(self, db_ops): import secrets @@ -502,3 +780,68 @@ async def test_execute_many_delete(self): # Verify all users are deleted users = await db_ops.read_query(query=r_query) assert len(users) == 0 + + @pytest.mark.asyncio + async def test_execute_one_exception(self, db_ops, mocker): + # Patch session to raise an exception + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query, params=None): raise Exception("fail!") + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + result = await db_ops.execute_one("fake_query") + assert isinstance(result, dict) + assert "error" in result + + @pytest.mark.asyncio + async def test_execute_many_exception(self, db_ops, mocker): + # Patch session to raise an exception + class FakeSession: + async def __aenter__(self): return self + async def __aexit__(self, exc_type, exc, tb): pass + async def execute(self, query, params=None): raise Exception("fail!") + + mocker.patch.object(db_ops.async_db, "get_db_session", return_value=FakeSession()) + # execute_many expects a list of queries + result = await db_ops.execute_many(["fake_query1", "fake_query2"]) + assert isinstance(result, dict) + assert "error" in result + + @pytest.mark.asyncio + async def test_get_columns_details_exception(self,db_ops, mocker): + # Patch logger and table to raise exception + class FakeTable: + __name__ = "Fake" + class __table__: + columns = property(lambda self: (_ for _ in ()).throw(Exception("fail!"))) + result = await db_ops.get_columns_details(FakeTable) + assert isinstance(result, dict) + assert "error" in result + + @pytest.mark.asyncio + async def test_get_primary_keys_exception(self,db_ops, mocker): + # Patch logger and table to raise exception + class FakeTable: + __name__ = "Fake" + class __table__: + class primary_key: + @staticmethod + def columns(): + raise Exception("fail!") + result = await db_ops.get_primary_keys(FakeTable) + assert isinstance(result, dict) + assert "error" in result + + @pytest.mark.asyncio + async def test_get_table_names_exception(self,db_ops, mocker): + # Patch async_db.Base.metadata.tables.keys to raise exception + class FakeMeta: + def keys(self): + raise Exception("fail!") + class FakeBase: + metadata = FakeMeta() + db_ops.async_db.Base = FakeBase() + result = await db_ops.get_table_names() + assert isinstance(result, dict) + assert "error" in result diff --git a/tests/test_file_functions.py b/tests/test_file_functions.py new file mode 100644 index 00000000..938ec109 --- /dev/null +++ b/tests/test_file_functions.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +""" +test_file_functions.py + +This module contains unit tests for the `file_functions` module in the +`dsg_lib.common_functions` package. + +Tests: + - test_save_json_appends_extension: Tests that the `save_json` function + appends the `.json` extension to the file name if it is missing. + +Author: Mike Ryan +Date: 2024/05/16 +License: MIT +""" + +from dsg_lib.common_functions import file_functions + + +def test_save_json_appends_extension(tmp_path): + """ + Test that the `save_json` function appends the `.json` extension to the file + name if it is missing. + """ + data = {"foo": "bar"} + file_name = "mytestfile" # No .json extension + result = file_functions.save_json(file_name, data, root_folder=str(tmp_path)) + expected_file = tmp_path / "mytestfile.json" + assert expected_file.exists() + assert result == "File saved successfully" + # Optionally, check file contents + with open(expected_file) as f: + import json + + assert json.load(f) == data