diff --git a/modules/fidesModule/persistence/sqlite_db.py b/modules/fidesModule/persistence/sqlite_db.py index 8fc48b7c1..e5e155558 100644 --- a/modules/fidesModule/persistence/sqlite_db.py +++ b/modules/fidesModule/persistence/sqlite_db.py @@ -3,6 +3,7 @@ Python has None, SQLite has NULL, conversion is automatic in both ways. """ + import os import sqlite3 from typing import List, Any, Optional @@ -38,7 +39,7 @@ def __init__(self, logger: Output, db_path: str) -> None: self.__create_tables() def __slips_log(self, txt: str) -> None: - self.logger.output_line( + self.logger.output_line_to_cli_and_logfiles( {"verbose": 2, "debug": 0, "from": self.name, "txt": txt} ) @@ -468,7 +469,9 @@ def __connect(self) -> None: Establishes a connection to the SQLite database. """ self.__slips_log(f"Connecting to SQLite database at {self.db_path}") - self.connection = sqlite3.connect(self.db_path, check_same_thread=False) + self.connection = sqlite3.connect( + self.db_path, check_same_thread=False + ) if self.connection is None: self.__slips_log("Failed to connect to the SQLite database!") diff --git a/modules/flowalerts/flowalerts.py b/modules/flowalerts/flowalerts.py index a671e4e20..dc132058d 100644 --- a/modules/flowalerts/flowalerts.py +++ b/modules/flowalerts/flowalerts.py @@ -61,7 +61,6 @@ def subscribe_to_channels(self): async def shutdown_gracefully(self): self.dns.shutdown_gracefully() - await asyncio.gather(*self.tasks, return_exceptions=True) def pre_main(self): utils.drop_root_privs() diff --git a/slips_files/common/abstracts/async_module.py b/slips_files/common/abstracts/async_module.py index 20510a4ce..6bd1d1e2c 100644 --- a/slips_files/common/abstracts/async_module.py +++ b/slips_files/common/abstracts/async_module.py @@ -35,7 +35,7 @@ def create_task(self, func, *args) -> Task: # Allow the event loop to run the scheduled task # await asyncio.sleep(0) - # to wait for these functions before flowalerts shuts down + # to wait for these functions before this module shuts down self.tasks.append(task) return task @@ -58,9 +58,16 @@ def handle_exception(self, task): async def main(self): ... async def shutdown_gracefully(self): - """Implement the async shutdown logic here""" + """ + Implement the async shutdown logic here + """ pass + async def gather_tasks_and_shutdown_gracefully(self): + """Implement the async shutdown logic here""" + await asyncio.gather(*self.tasks, return_exceptions=True) + await self.shutdown_gracefully() + def run_async_function(self, func: Callable): """ If the func argument is a coroutine object it is implicitly @@ -78,14 +85,18 @@ def run(self): try: error: bool = self.pre_main() if error or self.should_stop(): - self.run_async_function(self.shutdown_gracefully) + self.run_async_function( + self.gather_tasks_and_shutdown_gracefully + ) return except KeyboardInterrupt: - self.run_async_function(self.shutdown_gracefully) + self.run_async_function(self.gather_tasks_and_shutdown_gracefully) return except RuntimeError as e: if "Event loop stopped before Future completed" in str(e): - self.run_async_function(self.shutdown_gracefully) + self.run_async_function( + self.gather_tasks_and_shutdown_gracefully + ) return except Exception: self.print_traceback() @@ -94,14 +105,18 @@ def run(self): while True: try: if self.should_stop(): - self.run_async_function(self.shutdown_gracefully) + self.run_async_function( + self.gather_tasks_and_shutdown_gracefully + ) return # if a module's main() returns 1, it means there's an # error and it needs to stop immediately error: bool = self.run_async_function(self.main) if error: - self.run_async_function(self.shutdown_gracefully) + self.run_async_function( + self.gather_tasks_and_shutdown_gracefully + ) return except KeyboardInterrupt: @@ -114,7 +129,9 @@ def run(self): continue except RuntimeError as e: if "Event loop stopped before Future completed" in str(e): - self.run_async_function(self.shutdown_gracefully) + self.run_async_function( + self.gather_tasks_and_shutdown_gracefully + ) return except Exception: self.print_traceback() diff --git a/slips_files/core/output.py b/slips_files/core/output.py index 984079a43..5ec6567b3 100644 --- a/slips_files/core/output.py +++ b/slips_files/core/output.py @@ -176,7 +176,7 @@ def enough_debug(self, debug: int): """ return 0 < debug <= 3 and debug <= self.debug - def output_line(self, msg: dict): + def output_line_to_cli_and_logfiles(self, msg: dict): """ Prints to terminal and logfiles depending on the debug and verbose levels @@ -225,5 +225,4 @@ def update(self, msg: dict): if msg.get("log_to_logfiles_only", False): self.log_line(msg) else: - # output to terminal - self.output_line(msg) + self.output_line_to_cli_and_logfiles(msg) diff --git a/tests/test_output.py b/tests/test_output.py index ea205ed50..185a45e20 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -117,7 +117,7 @@ def test_output_line_all_outputs(mock_log_error, mock_log_line, mock_print): "debug": 1, } - output.output_line(msg) + output.output_line_to_cli_and_logfiles(msg) mock_print.assert_called_with(msg["from"], msg["txt"], end="\n") mock_log_line.assert_called_with(msg) @@ -143,7 +143,7 @@ def test_output_line_no_outputs(mock_log_error, mock_log_line, mock_print): "debug": 0, } - output.output_line(msg) + output.output_line_to_cli_and_logfiles(msg) mock_print.assert_not_called() mock_log_line.assert_not_called() @@ -165,7 +165,7 @@ def test_output_line_no_error_log(mock_log_error, mock_log_line, mock_print): "debug": 2, } - output.output_line(msg) + output.output_line_to_cli_and_logfiles(msg) mock_print.assert_called_with(msg["from"], msg["txt"], end="\n") mock_log_line.assert_called_with(msg) @@ -189,13 +189,15 @@ def test_update(msg, expected_output_line_calls): """Test that the update method handles different cases correctly.""" output = ModuleFactory().create_output_obj() - output.output_line = MagicMock() + output.output_line_to_cli_and_logfiles = MagicMock() output.update(msg) - assert output.output_line.call_count == len(expected_output_line_calls) + assert output.output_line_to_cli_and_logfiles.call_count == len( + expected_output_line_calls + ) for call in expected_output_line_calls: - output.output_line.assert_any_call(call) + output.output_line_to_cli_and_logfiles.assert_any_call(call) def test_update_log_to_logfiles_only():