Skip to content

gather tasks before exiting asyncmodule #1377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions modules/fidesModule/persistence/sqlite_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

Python has None, SQLite has NULL, conversion is automatic in both ways.
"""

import os
import sqlite3
from typing import List, Any, Optional
Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(self, logger: Output, db_path: str) -> None:
self.__create_tables()

def __slips_log(self, txt: str) -> None:
self.logger.output_line(
self.logger.output_line_to_cli_and_logfiles(
{"verbose": 2, "debug": 0, "from": self.name, "txt": txt}
)

Expand Down Expand Up @@ -468,7 +469,9 @@ def __connect(self) -> None:
Establishes a connection to the SQLite database.
"""
self.__slips_log(f"Connecting to SQLite database at {self.db_path}")
self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
self.connection = sqlite3.connect(
self.db_path, check_same_thread=False
)

if self.connection is None:
self.__slips_log("Failed to connect to the SQLite database!")
Expand Down
1 change: 0 additions & 1 deletion modules/flowalerts/flowalerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def subscribe_to_channels(self):

async def shutdown_gracefully(self):
self.dns.shutdown_gracefully()
await asyncio.gather(*self.tasks, return_exceptions=True)

def pre_main(self):
utils.drop_root_privs()
Expand Down
33 changes: 25 additions & 8 deletions slips_files/common/abstracts/async_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_task(self, func, *args) -> Task:
# Allow the event loop to run the scheduled task
# await asyncio.sleep(0)

# to wait for these functions before flowalerts shuts down
# to wait for these functions before this module shuts down
self.tasks.append(task)
return task

Expand All @@ -58,9 +58,16 @@ def handle_exception(self, task):
async def main(self): ...

async def shutdown_gracefully(self):
"""Implement the async shutdown logic here"""
"""
Implement the async shutdown logic here
"""
pass

async def gather_tasks_and_shutdown_gracefully(self):
"""Implement the async shutdown logic here"""
await asyncio.gather(*self.tasks, return_exceptions=True)
await self.shutdown_gracefully()

def run_async_function(self, func: Callable):
"""
If the func argument is a coroutine object it is implicitly
Expand All @@ -78,14 +85,18 @@ def run(self):
try:
error: bool = self.pre_main()
if error or self.should_stop():
self.run_async_function(self.shutdown_gracefully)
self.run_async_function(
self.gather_tasks_and_shutdown_gracefully
)
return
except KeyboardInterrupt:
self.run_async_function(self.shutdown_gracefully)
self.run_async_function(self.gather_tasks_and_shutdown_gracefully)
return
except RuntimeError as e:
if "Event loop stopped before Future completed" in str(e):
self.run_async_function(self.shutdown_gracefully)
self.run_async_function(
self.gather_tasks_and_shutdown_gracefully
)
return
except Exception:
self.print_traceback()
Expand All @@ -94,14 +105,18 @@ def run(self):
while True:
try:
if self.should_stop():
self.run_async_function(self.shutdown_gracefully)
self.run_async_function(
self.gather_tasks_and_shutdown_gracefully
)
return

# if a module's main() returns 1, it means there's an
# error and it needs to stop immediately
error: bool = self.run_async_function(self.main)
if error:
self.run_async_function(self.shutdown_gracefully)
self.run_async_function(
self.gather_tasks_and_shutdown_gracefully
)
return

except KeyboardInterrupt:
Expand All @@ -114,7 +129,9 @@ def run(self):
continue
except RuntimeError as e:
if "Event loop stopped before Future completed" in str(e):
self.run_async_function(self.shutdown_gracefully)
self.run_async_function(
self.gather_tasks_and_shutdown_gracefully
)
return
except Exception:
self.print_traceback()
Expand Down
5 changes: 2 additions & 3 deletions slips_files/core/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def enough_debug(self, debug: int):
"""
return 0 < debug <= 3 and debug <= self.debug

def output_line(self, msg: dict):
def output_line_to_cli_and_logfiles(self, msg: dict):
"""
Prints to terminal and logfiles depending on the debug and verbose
levels
Expand Down Expand Up @@ -225,5 +225,4 @@ def update(self, msg: dict):
if msg.get("log_to_logfiles_only", False):
self.log_line(msg)
else:
# output to terminal
self.output_line(msg)
self.output_line_to_cli_and_logfiles(msg)
14 changes: 8 additions & 6 deletions tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_output_line_all_outputs(mock_log_error, mock_log_line, mock_print):
"debug": 1,
}

output.output_line(msg)
output.output_line_to_cli_and_logfiles(msg)

mock_print.assert_called_with(msg["from"], msg["txt"], end="\n")
mock_log_line.assert_called_with(msg)
Expand All @@ -143,7 +143,7 @@ def test_output_line_no_outputs(mock_log_error, mock_log_line, mock_print):
"debug": 0,
}

output.output_line(msg)
output.output_line_to_cli_and_logfiles(msg)

mock_print.assert_not_called()
mock_log_line.assert_not_called()
Expand All @@ -165,7 +165,7 @@ def test_output_line_no_error_log(mock_log_error, mock_log_line, mock_print):
"debug": 2,
}

output.output_line(msg)
output.output_line_to_cli_and_logfiles(msg)

mock_print.assert_called_with(msg["from"], msg["txt"], end="\n")
mock_log_line.assert_called_with(msg)
Expand All @@ -189,13 +189,15 @@ def test_update(msg, expected_output_line_calls):
"""Test that the update method handles
different cases correctly."""
output = ModuleFactory().create_output_obj()
output.output_line = MagicMock()
output.output_line_to_cli_and_logfiles = MagicMock()

output.update(msg)

assert output.output_line.call_count == len(expected_output_line_calls)
assert output.output_line_to_cli_and_logfiles.call_count == len(
expected_output_line_calls
)
for call in expected_output_line_calls:
output.output_line.assert_any_call(call)
output.output_line_to_cli_and_logfiles.assert_any_call(call)


def test_update_log_to_logfiles_only():
Expand Down
Loading