diff --git a/aikido_zen/background_process/cloud_connection_manager/__init__.py b/aikido_zen/background_process/cloud_connection_manager/__init__.py index 16f6a9bb4..acbcb06d1 100644 --- a/aikido_zen/background_process/cloud_connection_manager/__init__.py +++ b/aikido_zen/background_process/cloud_connection_manager/__init__.py @@ -10,7 +10,7 @@ from aikido_zen.storage.users import Users from aikido_zen.storage.hostnames import Hostnames from ..realtime.start_polling_for_changes import start_polling_for_changes -from ..statistics import Statistics +from ...storage.statistics import Statistics # Import functions : from .on_detected_attack import on_detected_attack @@ -45,9 +45,7 @@ def __init__(self, block, api, token, serverless): ) self.users = Users(1000) self.packages = {} - self.statistics = Statistics( - max_perf_samples_in_mem=5000, max_compressed_stats_in_mem=100 - ) + self.statistics = Statistics() self.middleware_installed = False if isinstance(serverless, str) and len(serverless) == 0: @@ -71,12 +69,8 @@ def report_initial_stats(self): This is run 1m after startup, and checks if we should send out a preliminary heartbeat with some stats. """ - data_is_available = not ( - self.statistics.is_empty() and len(self.routes.routes) <= 0 - ) - should_report_initial_stats = ( - data_is_available and not self.conf.received_any_stats - ) + data_present = not self.statistics.empty() or len(self.routes.routes) > 0 + should_report_initial_stats = data_present and not self.conf.received_any_stats if should_report_initial_stats: self.send_heartbeat() diff --git a/aikido_zen/background_process/cloud_connection_manager/init_test.py b/aikido_zen/background_process/cloud_connection_manager/init_test.py index ac43fd4fc..3c58a85af 100644 --- a/aikido_zen/background_process/cloud_connection_manager/init_test.py +++ b/aikido_zen/background_process/cloud_connection_manager/init_test.py @@ -5,7 +5,7 @@ from aikido_zen.storage.users import Users from aikido_zen.storage.hostnames import Hostnames from aikido_zen.ratelimiting.rate_limiter import RateLimiter -from aikido_zen.background_process.statistics import Statistics +from aikido_zen.storage.statistics import Statistics from . import CloudConnectionManager diff --git a/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py b/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py index 4b97a388e..409a00976 100644 --- a/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py +++ b/aikido_zen/background_process/cloud_connection_manager/send_heartbeat.py @@ -11,12 +11,12 @@ def send_heartbeat(connection_manager): if not connection_manager.token: return logger.debug("Aikido CloudConnectionManager : Sending out heartbeat") - stats = connection_manager.statistics.get_stats() + stats = connection_manager.statistics.get_record() users = connection_manager.users.as_array() routes = list(connection_manager.routes) outgoing_domains = connection_manager.hostnames.as_array() - connection_manager.statistics.reset() + connection_manager.statistics.clear() connection_manager.users.clear() connection_manager.routes.clear() connection_manager.hostnames.clear() diff --git a/aikido_zen/background_process/commands/attack.py b/aikido_zen/background_process/commands/attack.py index d275f38ee..92308e989 100644 --- a/aikido_zen/background_process/commands/attack.py +++ b/aikido_zen/background_process/commands/attack.py @@ -7,5 +7,3 @@ def process_attack(connection_manager, data, queue): Expected data object : [injection_results, context, blocked_or_not, stacktrace] """ queue.put(data) - if connection_manager and connection_manager.statistics: - connection_manager.statistics.on_detected_attack(blocked=data[2]) diff --git a/aikido_zen/background_process/commands/attack_test.py b/aikido_zen/background_process/commands/attack_test.py index 845c5f770..b5cb8e81d 100644 --- a/aikido_zen/background_process/commands/attack_test.py +++ b/aikido_zen/background_process/commands/attack_test.py @@ -20,18 +20,6 @@ def test_process_attack_adds_data_to_queue(): assert queue.get() == data -def test_process_attack_statistics_called_when_enabled(): - queue = Queue() - connection_manager = MockCloudConnectionManager() - data = ("injection_results", "context", True, "stacktrace") # Example data - process_attack(connection_manager, data, queue) - - # Check if on_detected_attack was called - connection_manager.statistics.on_detected_attack.assert_called_once_with( - blocked=True - ) - - def test_process_attack_statistics_not_called_when_disabled(): queue = Queue() connection_manager = MockCloudConnectionManager() diff --git a/aikido_zen/background_process/commands/sync_data.py b/aikido_zen/background_process/commands/sync_data.py index ebba56b31..ee453e834 100644 --- a/aikido_zen/background_process/commands/sync_data.py +++ b/aikido_zen/background_process/commands/sync_data.py @@ -8,9 +8,11 @@ def process_sync_data(connection_manager, data, conn, queue=None): """ Synchronizes data between the thread-local cache (with a TTL of usually 1 minute) and the background thread. Which data gets synced? - Thread -> BG Process : Hits, request statistics, api specs, hostnames - BG Process -> Thread : Routes, endpoints, bypasssed ip's, blocked users + Thread -> BG Process : Routes, Hostnames, Users, Stats & middleware installed + BG Process -> Thread : Routes and config """ + + # Sync routes routes = connection_manager.routes for route in data.get("current_routes", {}).values(): route_metadata = {"method": route["method"], "route": route["path"]} @@ -25,9 +27,6 @@ def process_sync_data(connection_manager, data, conn, queue=None): # Update API Spec : update_route_info(route["apispec"], existing_route) - # Save request data : - connection_manager.statistics.requests["total"] += data.get("reqs", 0) - # Save middleware installed : if data.get("middleware_installed", False): connection_manager.middleware_installed = True @@ -44,6 +43,9 @@ def process_sync_data(connection_manager, data, conn, queue=None): for user_entry in data.get("users", list()): connection_manager.users.add_user_from_entry(user_entry) + # Sync stats + connection_manager.statistics.import_from_record(data.get("stats", {})) + if connection_manager.conf.last_updated_at > 0: # Only report data if the config has been fetched. return { diff --git a/aikido_zen/background_process/commands/sync_data_test.py b/aikido_zen/background_process/commands/sync_data_test.py index c60e74246..de5958c9a 100644 --- a/aikido_zen/background_process/commands/sync_data_test.py +++ b/aikido_zen/background_process/commands/sync_data_test.py @@ -6,6 +6,7 @@ from aikido_zen.background_process.routes import Routes from aikido_zen.helpers.iplist import IPList from ...storage.hostnames import Hostnames +from ...storage.statistics import Statistics @pytest.fixture @@ -20,7 +21,7 @@ def setup_connection_manager(): connection_manager.conf.bypassed_ips.add("192.168.1.1") connection_manager.conf.blocked_uids = ["user1", "user2"] connection_manager.conf.last_updated_at = 200 - connection_manager.statistics.requests = {"total": 0} # Initialize total requests + connection_manager.statistics = Statistics() connection_manager.middleware_installed = False return connection_manager @@ -47,7 +48,18 @@ def test_process_sync_data_initialization(setup_connection_manager): "apispec": {"info": "API spec for resource"}, }, }, - "reqs": 10, # Total requests to be added + "stats": { + "startedAt": 1, + "endedAt": 1, + "requests": { + "total": 10, + "aborted": 0, + "attacksDetected": { + "total": 5, + "blocked": 0, + }, + }, + }, "middleware_installed": False, "hostnames": test_hostnames.as_array(), } @@ -70,7 +82,11 @@ def test_process_sync_data_initialization(setup_connection_manager): ) # Check that the total requests were updated - assert connection_manager.statistics.requests["total"] == 10 + assert connection_manager.statistics.get_record()["requests"] == { + "aborted": 0, + "attacksDetected": {"blocked": 0, "total": 5}, + "total": 10, + } # Check that the return value is correct assert result["routes"] == dict(connection_manager.routes.routes) @@ -101,7 +117,18 @@ def test_process_sync_data_with_last_updated_at_below_zero(setup_connection_mana "apispec": {"info": "API spec for resource"}, }, }, - "reqs": 10, # Total requests to be added + "stats": { + "startedAt": 1, + "endedAt": 1, + "requests": { + "total": 10, + "aborted": 0, + "attacksDetected": { + "total": 5, + "blocked": 0, + }, + }, + }, "middleware_installed": True, } @@ -123,7 +150,11 @@ def test_process_sync_data_with_last_updated_at_below_zero(setup_connection_mana ) # Check that the total requests were updated - assert connection_manager.statistics.requests["total"] == 10 + assert connection_manager.statistics.get_record()["requests"] == { + "aborted": 0, + "attacksDetected": {"blocked": 0, "total": 5}, + "total": 10, + } assert connection_manager.middleware_installed == True assert len(connection_manager.hostnames.as_array()) == 0 # Check that the return value is correct @@ -150,7 +181,18 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager "apispec": {"info": "API spec for resource"}, } }, - "reqs": 5, # Total requests to be added + "stats": { + "startedAt": 1, + "endedAt": 1, + "requests": { + "total": 5, + "aborted": 0, + "attacksDetected": { + "total": 5, + "blocked": 0, + }, + }, + }, "hostnames": hostnames_sync.as_array(), } @@ -167,7 +209,18 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager "apispec": {"info": "Updated API spec for resource"}, } }, - "reqs": 15, # Additional requests to be added + "stats": { + "startedAt": 1, + "endedAt": 1, + "requests": { + "total": 15, + "aborted": 0, + "attacksDetected": { + "total": 5, + "blocked": 0, + }, + }, + }, } result = process_sync_data(connection_manager, data_update, None) @@ -181,7 +234,11 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager ) # Check that the total requests were updated - assert connection_manager.statistics.requests["total"] == 20 # 5 + 15 + assert connection_manager.statistics.get_record()["requests"] == { + "aborted": 0, + "attacksDetected": {"blocked": 0, "total": 10}, + "total": 20, + } assert connection_manager.middleware_installed == False assert connection_manager.hostnames.as_array() == [ {"hits": 215, "hostname": "example.com", "port": 443}, diff --git a/aikido_zen/background_process/statistics/__init__.py b/aikido_zen/background_process/statistics/__init__.py deleted file mode 100644 index 14273a793..000000000 --- a/aikido_zen/background_process/statistics/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Exports statistics class""" - -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms -from .ensure_sink_stats import ensure_sink_stats -from .compress_perf_samples import compress_perf_samples -from .on_inspected_call import on_inspected_call -from .get_stats import get_stats - - -class Statistics: - """ - Keeps track of total and aborted requests - and total and blocked attacks - """ - - def __init__(self, max_perf_samples_in_mem, max_compressed_stats_in_mem): - self.max_perf_samples_in_mem = max_perf_samples_in_mem - self.max_compressed_stats_in_mem = max_compressed_stats_in_mem - self.reset() - - def reset(self): - """Resets the stored data to an initial state""" - self.stats = {} - self.requests = { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - } - self.started_at = get_unixtime_ms() - - def has_compressed_stats(self): - """Checks if there are any compressed statistics""" - return any( - len(sink_stats["compressedTimings"]) > 0 - for sink_stats in self.stats.values() - ) - - def interceptor_threw_error(self, sink): - """Increment the error count for the interceptor for the given sink.""" - self.ensure_sink_stats(sink) - self.stats[sink]["total"] += 1 - self.stats[sink]["interceptorThrewError"] += 1 - - def on_detected_attack(self, blocked): - """Increment the attack detection statistics.""" - self.requests["attacksDetected"]["total"] += 1 - if blocked: - self.requests["attacksDetected"]["blocked"] += 1 - - def force_compress(self): - """Force compression of performance samples for all sinks.""" - for sink in self.stats: - self.compress_perf_samples(sink) - - def ensure_sink_stats(self, sink): - """Makes sure to initalize sink if it's not there""" - return ensure_sink_stats(self, sink) - - def compress_perf_samples(self, sink): - """Compress performance samples for a given sink.""" - return compress_perf_samples(self, sink) - - def on_inspected_call(self, *args, **kwargs): - """Handle an inspected call and update statistics accordingly.""" - return on_inspected_call(self, *args, **kwargs) - - def get_stats(self): - """This will return the stats as a dict, from a Statistics class""" - return get_stats(self) - - def is_empty(self): - """This will return a boolean value indicating if the stats are empty""" - return ( - len(self.stats) == 0 - and self.requests["total"] == 0 - and self.requests["attacksDetected"]["total"] == 0 - ) diff --git a/aikido_zen/background_process/statistics/compress_perf_samples.py b/aikido_zen/background_process/statistics/compress_perf_samples.py deleted file mode 100644 index 36c81b311..000000000 --- a/aikido_zen/background_process/statistics/compress_perf_samples.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Exports `compress_perf_samples` function""" - -from aikido_zen.helpers.percentiles import percentiles -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms - - -def compress_perf_samples(statistics_object, sink): - """Compress performance samples for a given sink.""" - # Ignore if sink stats do not exist or if there are no durations - if ( - sink not in statistics_object.stats - or not statistics_object.stats[sink]["durations"] - ): - return - - timings = statistics_object.stats[sink]["durations"] - average_in_ms = sum(timings) / len(timings) - - p50, p75, p90, p95, p99 = percentiles([50, 75, 90, 95, 99], timings) - - statistics_object.stats[sink]["compressedTimings"].append( - { - "averageInMS": average_in_ms, - "percentiles": { - "50": p50, - "75": p75, - "90": p90, - "95": p95, - "99": p99, - }, - "compressedAt": get_unixtime_ms(), - } - ) - - # Remove the oldest compressed timing if exceeding the limit - if ( - len(statistics_object.stats[sink]["compressedTimings"]) - > statistics_object.max_compressed_stats_in_mem - ): - statistics_object.stats[sink]["compressedTimings"].pop(0) - - # Clear the durations - statistics_object.stats[sink]["durations"] = [] diff --git a/aikido_zen/background_process/statistics/compress_perf_samples_test.py b/aikido_zen/background_process/statistics/compress_perf_samples_test.py deleted file mode 100644 index b6d323523..000000000 --- a/aikido_zen/background_process/statistics/compress_perf_samples_test.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -from unittest.mock import MagicMock -from aikido_zen.helpers.percentiles import percentiles -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms -from .compress_perf_samples import compress_perf_samples - - -def test_no_durations(): - """Test when there are no durations.""" - statistics_object = MagicMock() - statistics_object.stats = {"sink1": {"durations": [], "compressedTimings": []}} - compress_perf_samples(statistics_object, "sink1") - assert statistics_object.stats["sink1"]["compressedTimings"] == [] - print("test_no_durations passed") - - -def test_clear_durations_after_compression(): - """Test that durations are cleared after compression.""" - statistics_object = MagicMock() - statistics_object.stats = { - "sink1": {"durations": [100, 200, 300], "compressedTimings": []} - } - statistics_object.max_compressed_stats_in_mem = 5 - - # Mock the percentiles function - percentiles.return_value = (200, 300, 400, 450, 490) - - # Mock the get_unixtime_ms function - get_unixtime_ms.return_value = 1234567890 - - compress_perf_samples(statistics_object, "sink1") - - assert ( - statistics_object.stats["sink1"]["durations"] == [] - ) # Durations should be cleared - print("test_clear_durations_after_compression passed") diff --git a/aikido_zen/background_process/statistics/ensure_sink_stats.py b/aikido_zen/background_process/statistics/ensure_sink_stats.py deleted file mode 100644 index 7a73cacd3..000000000 --- a/aikido_zen/background_process/statistics/ensure_sink_stats.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Exports ensure_sink_stats function""" - -import copy - -EMPTY_STATS_OBJECT = { - "withoutContext": 0, - "total": 0, - "durations": [], - "compressedTimings": [], - "interceptorThrewError": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, -} - - -def ensure_sink_stats(statistics_obj, sink): - """Makes sure to initalize sink if it's not there""" - if sink not in statistics_obj.stats: - statistics_obj.stats[sink] = copy.deepcopy(EMPTY_STATS_OBJECT) diff --git a/aikido_zen/background_process/statistics/ensure_sink_stats_test.py b/aikido_zen/background_process/statistics/ensure_sink_stats_test.py deleted file mode 100644 index 3c113b845..000000000 --- a/aikido_zen/background_process/statistics/ensure_sink_stats_test.py +++ /dev/null @@ -1,58 +0,0 @@ -import pytest -from .ensure_sink_stats import ensure_sink_stats, EMPTY_STATS_OBJECT - - -class MockStatistics: - def __init__(self): - self.stats = {} - - -def test_ensure_sink_stats_initialization(): - statistics_obj = MockStatistics() - sink = "example_sink" - - # Ensure the sink stats are initialized - ensure_sink_stats(statistics_obj, sink) - - # Check if the sink stats are correctly initialized - assert sink in statistics_obj.stats - assert statistics_obj.stats[sink] == EMPTY_STATS_OBJECT - - -def test_ensure_sink_stats_already_initialized(): - statistics_obj = MockStatistics() - sink = "example_sink" - - # Initialize the sink stats first - ensure_sink_stats(statistics_obj, sink) - - # Ensure calling it again does not overwrite the existing stats - ensure_sink_stats(statistics_obj, sink) - - # Check if the sink stats are still the same - assert statistics_obj.stats[sink] == EMPTY_STATS_OBJECT - - -def test_ensure_sink_stats_multiple_sinks(): - statistics_obj = MockStatistics() - - # Initialize multiple sinks - ensure_sink_stats(statistics_obj, "sink_one") - ensure_sink_stats(statistics_obj, "sink_two") - - # Check if both sinks are initialized correctly - assert "sink_one" in statistics_obj.stats - assert "sink_two" in statistics_obj.stats - assert statistics_obj.stats["sink_one"] == EMPTY_STATS_OBJECT - assert statistics_obj.stats["sink_two"] == EMPTY_STATS_OBJECT - - -def test_ensure_sink_stats_empty_stats_object(): - statistics_obj = MockStatistics() - sink = "empty_sink" - - # Ensure the sink stats are initialized - ensure_sink_stats(statistics_obj, sink) - - # Check if the initialized stats object is empty - assert statistics_obj.stats[sink] == EMPTY_STATS_OBJECT diff --git a/aikido_zen/background_process/statistics/get_stats.py b/aikido_zen/background_process/statistics/get_stats.py deleted file mode 100644 index 1cc488568..000000000 --- a/aikido_zen/background_process/statistics/get_stats.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Exports get_stats function""" - -import aikido_zen.helpers.get_current_unixtime_ms as t - - -def get_stats(statistics_object): - """This will return the stats as a dict, from a Statistics class""" - sinks = {} - for sink, sink_stats in statistics_object.stats.items(): - sinks[sink] = { - "total": sink_stats["total"], - "attacksDetected": { - "total": sink_stats["attacksDetected"]["total"], - "blocked": sink_stats["attacksDetected"]["blocked"], - }, - "interceptorThrewError": sink_stats["interceptorThrewError"], - "withoutContext": sink_stats["withoutContext"], - "compressedTimings": sink_stats["compressedTimings"], - } - - return { - "sinks": sinks, - "startedAt": statistics_object.started_at, - "endedAt": t.get_unixtime_ms(), - "requests": statistics_object.requests, - } diff --git a/aikido_zen/background_process/statistics/get_stats_test.py b/aikido_zen/background_process/statistics/get_stats_test.py deleted file mode 100644 index 28f8c3589..000000000 --- a/aikido_zen/background_process/statistics/get_stats_test.py +++ /dev/null @@ -1,151 +0,0 @@ -import pytest -from .get_stats import get_stats - - -class MockStatistics: - def __init__(self): - self.stats = {} - self.started_at = None - self.requests = 0 - - -def test_get_stats_single_sink(monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 2 - ) - statistics_object = MockStatistics() - statistics_object.stats = { - "sink_one": { - "total": 10, - "attacksDetected": {"total": 2, "blocked": 1}, - "interceptorThrewError": 0, - "withoutContext": 5, - "compressedTimings": [], - } - } - statistics_object.started_at = "2023-01-01T00:00:00Z" - statistics_object.requests = 30 - - expected = { - "sinks": { - "sink_one": { - "total": 10, - "attacksDetected": {"total": 2, "blocked": 1}, - "interceptorThrewError": 0, - "withoutContext": 5, - "compressedTimings": [], - } - }, - "startedAt": "2023-01-01T00:00:00Z", - "endedAt": 2, - "requests": 30, - } - - assert get_stats(statistics_object) == expected - - -def test_get_stats_multiple_sinks(monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 5 - ) - statistics_object = MockStatistics() - statistics_object.stats = { - "sink_one": { - "total": 10, - "attacksDetected": {"total": 2, "blocked": 1}, - "interceptorThrewError": 0, - "withoutContext": 5, - "compressedTimings": [], - }, - "sink_two": { - "total": 20, - "attacksDetected": {"total": 3, "blocked": 2}, - "interceptorThrewError": 1, - "withoutContext": 10, - "compressedTimings": [], - }, - } - statistics_object.started_at = "2023-01-01T00:00:00Z" - statistics_object.requests = 30 - - expected = { - "sinks": { - "sink_one": { - "total": 10, - "attacksDetected": {"total": 2, "blocked": 1}, - "interceptorThrewError": 0, - "withoutContext": 5, - "compressedTimings": [], - }, - "sink_two": { - "total": 20, - "attacksDetected": {"total": 3, "blocked": 2}, - "interceptorThrewError": 1, - "withoutContext": 10, - "compressedTimings": [], - }, - }, - "startedAt": "2023-01-01T00:00:00Z", - "endedAt": 5, - "requests": 30, - } - - assert get_stats(statistics_object) == expected - - -def test_get_stats_empty_stats(monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 1 - ) - statistics_object = MockStatistics() - statistics_object.stats = {} - statistics_object.started_at = "2023-01-01T00:00:00Z" - statistics_object.requests = 0 - - expected = { - "sinks": {}, - "startedAt": "2023-01-01T00:00:00Z", - "endedAt": 1, - "requests": 0, - } - - assert get_stats(statistics_object) == expected - - -def test_get_stats_no_started_at(monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 10 - ) - statistics_object = MockStatistics() - statistics_object.stats = { - "sink_one": { - "total": 10, - "attacksDetected": {"total": 2, "blocked": 1}, - "interceptorThrewError": 0, - "withoutContext": 5, - "compressedTimings": [], - } - } - statistics_object.started_at = None - statistics_object.requests = 30 - - expected = { - "sinks": { - "sink_one": { - "total": 10, - "attacksDetected": {"total": 2, "blocked": 1}, - "interceptorThrewError": 0, - "withoutContext": 5, - "compressedTimings": [], - } - }, - "startedAt": None, - "endedAt": 10, - "requests": 30, - } - - assert get_stats(statistics_object) == expected - - -if __name__ == "__main__": - pytest.main() diff --git a/aikido_zen/background_process/statistics/init_test.py b/aikido_zen/background_process/statistics/init_test.py deleted file mode 100644 index 0abfb0724..000000000 --- a/aikido_zen/background_process/statistics/init_test.py +++ /dev/null @@ -1,428 +0,0 @@ -import pytest -from unittest.mock import MagicMock -from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms -from . import Statistics - - -@pytest.fixture -def stats(): - """Fixture to create a new instance of Statistics.""" - return Statistics(max_perf_samples_in_mem=50, max_compressed_stats_in_mem=5) - - -def test_it_resets_stats(stats, monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 2 - ) - stats.on_inspected_call( - without_context=False, - sink="mongodb", - blocked=False, - duration_in_ms=0.1, - attack_detected=False, - ) - started_at = stats.get_stats()["startedAt"] - - assert stats.get_stats() == { - "sinks": { - "mongodb": { - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - "interceptorThrewError": 0, - "withoutContext": 0, - "total": 1, - "compressedTimings": [], - }, - }, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.reset() - started_at = stats.get_stats()["startedAt"] - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, # Assuming reset sets this to the current time - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - -def test_it_keeps_track_of_amount_of_calls(stats, monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 2 - ) - started_at = stats.get_stats()["startedAt"] - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.on_inspected_call( - without_context=False, - sink="mongodb", - blocked=False, - duration_in_ms=0.1, - attack_detected=False, - ) - - assert stats.get_stats() == { - "sinks": { - "mongodb": { - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - "interceptorThrewError": 0, - "withoutContext": 0, - "total": 1, - "compressedTimings": [], - }, - }, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.on_inspected_call( - without_context=True, - sink="mongodb", - blocked=False, - duration_in_ms=0.1, - attack_detected=False, - ) - - assert stats.get_stats() == { - "sinks": { - "mongodb": { - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - "interceptorThrewError": 0, - "withoutContext": 1, - "total": 2, - "compressedTimings": [], - }, - }, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.interceptor_threw_error("mongodb") - - assert stats.get_stats() == { - "sinks": { - "mongodb": { - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - "interceptorThrewError": 1, - "withoutContext": 1, - "total": 3, - "compressedTimings": [], - }, - }, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.on_inspected_call( - without_context=False, - sink="mongodb", - blocked=False, - duration_in_ms=0.1, - attack_detected=True, - ) - - assert stats.get_stats() == { - "sinks": { - "mongodb": { - "attacksDetected": { - "total": 1, - "blocked": 0, - }, - "interceptorThrewError": 1, - "withoutContext": 1, - "total": 4, - "compressedTimings": [], - }, - }, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.on_inspected_call( - without_context=False, - sink="mongodb", - blocked=True, - duration_in_ms=0.3, - attack_detected=True, - ) - - assert stats.get_stats() == { - "sinks": { - "mongodb": { - "attacksDetected": { - "total": 2, - "blocked": 1, - }, - "interceptorThrewError": 1, - "withoutContext": 1, - "total": 5, - "compressedTimings": [], - }, - }, - "startedAt": started_at, - "endedAt": 2, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - assert stats.has_compressed_stats() is False - - for i in range(50): - stats.on_inspected_call( - without_context=False, - sink="mongodb", - blocked=False, - duration_in_ms=i * 0.1, - attack_detected=False, - ) - - assert stats.has_compressed_stats() is True - - # Check the compressed timings - assert len(stats.get_stats()["sinks"]["mongodb"]["compressedTimings"]) == 1 - - -def test_it_keeps_track_of_requests(stats, monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 9 - ) - started_at = stats.get_stats()["startedAt"] - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, - "endedAt": 9, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.requests["total"] += 1 - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, - "endedAt": 9, - "requests": { - "total": 1, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - stats.requests["total"] += 1 - stats.on_detected_attack(blocked=False) - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, - "endedAt": 9, - "requests": { - "total": 2, - "aborted": 0, - "attacksDetected": { - "total": 1, - "blocked": 0, - }, - }, - } - - stats.requests["total"] += 1 - stats.on_detected_attack(blocked=True) - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, - "endedAt": 9, - "requests": { - "total": 3, - "aborted": 0, - "attacksDetected": { - "total": 2, - "blocked": 1, - }, - }, - } - - stats.reset() - started_at = stats.get_stats()["startedAt"] - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, # Assuming reset sets this to the current time - "endedAt": 9, - "requests": { - "total": 0, - "aborted": 0, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - -def test_it_force_compresses_stats(stats): - stats.requests["total"] += 1 - - stats.on_inspected_call( - without_context=False, - sink="mongodb", - blocked=False, - duration_in_ms=0.1, - attack_detected=False, - ) - - assert stats.has_compressed_stats() is False - - stats.force_compress() - - assert stats.has_compressed_stats() is True - - -def test_it_keeps_track_of_aborted_requests(stats, monkeypatch): - monkeypatch.setattr( - "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: 5 - ) - stats.requests["aborted"] += 1 - started_at = stats.get_stats()["startedAt"] - - assert stats.get_stats() == { - "sinks": {}, - "startedAt": started_at, - "endedAt": 5, - "requests": { - "total": 0, - "aborted": 1, - "attacksDetected": { - "total": 0, - "blocked": 0, - }, - }, - } - - -def test_is_empty_when_stats_are_empty(stats): - assert stats.is_empty() is True - - -def test_is_empty_when_requests_are_empty(stats): - stats.requests["total"] = 0 - stats.requests["attacksDetected"]["total"] = 0 - stats.stats = {} # Assuming stats is a dictionary - assert stats.is_empty() is True - - -def test_is_empty_when_requests_have_data(stats): - stats.requests["total"] = 1 - stats.requests["attacksDetected"]["total"] = 0 - stats.stats = {} - assert stats.is_empty() is False - - -def test_is_empty_when_attacks_detected(stats): - stats.requests["total"] = 0 - stats.requests["attacksDetected"]["total"] = 1 - stats.stats = {} - assert stats.is_empty() is False - - -def test_is_empty_when_stats_have_data(stats): - stats.requests["total"] = 0 - stats.requests["attacksDetected"]["total"] = 0 - stats.stats = {"some_stat": 1} # Adding some data to stats - assert stats.is_empty() is False - - -def test_is_empty_when_all_data_present(stats): - stats.requests["total"] = 1 - stats.requests["attacksDetected"]["total"] = 1 - stats.stats = {"some_stat": 1} - assert stats.is_empty() is False diff --git a/aikido_zen/background_process/statistics/on_inspected_call.py b/aikido_zen/background_process/statistics/on_inspected_call.py deleted file mode 100644 index f325c7f66..000000000 --- a/aikido_zen/background_process/statistics/on_inspected_call.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Exports `on_inspected_call`""" - - -def on_inspected_call( - statistics_object, sink, duration_in_ms, attack_detected, blocked, without_context -): - """Handle an inspected call and update statistics accordingly.""" - statistics_object.ensure_sink_stats(sink) - - statistics_object.stats[sink]["total"] += 1 - - if without_context: - statistics_object.stats[sink]["withoutContext"] += 1 - return - - if ( - len(statistics_object.stats[sink]["durations"]) - >= statistics_object.max_perf_samples_in_mem - ): - statistics_object.compress_perf_samples(sink) - - statistics_object.stats[sink]["durations"].append(duration_in_ms) - - if attack_detected: - statistics_object.stats[sink]["attacksDetected"]["total"] += 1 - if blocked: - statistics_object.stats[sink]["attacksDetected"]["blocked"] += 1 diff --git a/aikido_zen/background_process/statistics/on_inspected_call_test.py b/aikido_zen/background_process/statistics/on_inspected_call_test.py deleted file mode 100644 index 178be5942..000000000 --- a/aikido_zen/background_process/statistics/on_inspected_call_test.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest -from unittest.mock import MagicMock -from .on_inspected_call import on_inspected_call - - -def test_no_context(): - """Test when without_context is True.""" - statistics_object = MagicMock() - statistics_object.stats = { - "sink1": { - "total": 0, - "withoutContext": 0, - "durations": [], - "attacksDetected": {"total": 0, "blocked": 0}, - } - } - statistics_object.max_perf_samples_in_mem = 5 - - on_inspected_call(statistics_object, "sink1", 150, False, False, True) - - assert statistics_object.stats["sink1"]["total"] == 1 - assert statistics_object.stats["sink1"]["withoutContext"] == 1 - assert statistics_object.stats["sink1"]["durations"] == [] - - -def test_with_context_and_no_attack(): - """Test when there is context and no attack detected.""" - statistics_object = MagicMock() - statistics_object.stats = { - "sink1": { - "total": 0, - "withoutContext": 0, - "durations": [], - "attacksDetected": {"total": 0, "blocked": 0}, - } - } - statistics_object.max_perf_samples_in_mem = 5 - - on_inspected_call(statistics_object, "sink1", 150, False, False, False) - - assert statistics_object.stats["sink1"]["total"] == 1 - assert statistics_object.stats["sink1"]["withoutContext"] == 0 - assert statistics_object.stats["sink1"]["durations"] == [150] - assert statistics_object.stats["sink1"]["attacksDetected"]["total"] == 0 - - -def test_with_context_and_attack_detected(): - """Test when there is context and an attack is detected.""" - statistics_object = MagicMock() - statistics_object.stats = { - "sink1": { - "total": 0, - "withoutContext": 0, - "durations": [], - "attacksDetected": {"total": 0, "blocked": 0}, - } - } - statistics_object.max_perf_samples_in_mem = 5 - - on_inspected_call(statistics_object, "sink1", 200, True, True, False) - - assert statistics_object.stats["sink1"]["total"] == 1 - assert statistics_object.stats["sink1"]["durations"] == [200] - assert statistics_object.stats["sink1"]["attacksDetected"]["total"] == 1 - assert statistics_object.stats["sink1"]["attacksDetected"]["blocked"] == 1 - - -def test_compress_samples_when_limit_exceeded(): - """Test when the number of durations exceeds the limit.""" - statistics_object = MagicMock() - statistics_object.stats = { - "sink1": { - "total": 0, - "withoutContext": 0, - "durations": [100, 200, 300, 400, 500], - "attacksDetected": {"total": 0, "blocked": 0}, - } - } - statistics_object.max_perf_samples_in_mem = 5 - - # Mock the compress_perf_samples function to avoid actual compression - statistics_object.compress_perf_samples = MagicMock() - - on_inspected_call(statistics_object, "sink1", 600, False, False, False) - - assert statistics_object.stats["sink1"]["durations"] == [ - 100, - 200, - 300, - 400, - 500, - 600, - ] diff --git a/aikido_zen/sources/functions/request_handler.py b/aikido_zen/sources/functions/request_handler.py index cbec77647..3250ca99d 100644 --- a/aikido_zen/sources/functions/request_handler.py +++ b/aikido_zen/sources/functions/request_handler.py @@ -13,9 +13,9 @@ def request_handler(stage, status_code=0): """This will check for rate limiting, Allowed IP's, useful routes, etc.""" try: if stage == "init": - thread_cache = get_cache() - if ctx.get_current_context() and thread_cache: - thread_cache.increment_stats() # Increment request statistics if a context exists. + cache = get_cache() + if ctx.get_current_context() and cache: + cache.stats.increment_total_hits() if stage == "pre_response": return pre_response() if stage == "post_response": diff --git a/aikido_zen/storage/statistics/__init__.py b/aikido_zen/storage/statistics/__init__.py new file mode 100644 index 000000000..93a08e693 --- /dev/null +++ b/aikido_zen/storage/statistics/__init__.py @@ -0,0 +1,56 @@ +import aikido_zen.helpers.get_current_unixtime_ms as t + + +class Statistics: + """ + Keeps track of total and aborted requests + and total and blocked attacks + """ + + def __init__(self): + self.total_hits = 0 + self.attacks_detected = 0 + self.attacks_blocked = 0 + self.started_at = t.get_unixtime_ms() + + def clear(self): + self.total_hits = 0 + self.attacks_detected = 0 + self.attacks_blocked = 0 + self.started_at = t.get_unixtime_ms() + + def increment_total_hits(self): + self.total_hits += 1 + + def on_detected_attack(self, blocked): + self.attacks_detected += 1 + if blocked: + self.attacks_blocked += 1 + + def get_record(self): + current_time = t.get_unixtime_ms() + return { + "startedAt": self.started_at, + "endedAt": current_time, + "requests": { + "total": self.total_hits, + "aborted": 0, # statistic currently not in use + "attacksDetected": { + "total": self.attacks_detected, + "blocked": self.attacks_blocked, + }, + }, + } + + def import_from_record(self, record): + attacks_detected = record.get("requests", {}).get("attacksDetected", {}) + self.total_hits += record.get("requests", {}).get("total", 0) + self.attacks_detected += attacks_detected.get("total", 0) + self.attacks_blocked += attacks_detected.get("blocked", 0) + + def empty(self): + if self.total_hits > 0: + return False + if self.attacks_detected > 0: + return False + return True diff --git a/aikido_zen/storage/statistics/init_test.py b/aikido_zen/storage/statistics/init_test.py new file mode 100644 index 000000000..3ea803668 --- /dev/null +++ b/aikido_zen/storage/statistics/init_test.py @@ -0,0 +1,169 @@ +import pytest +from . import Statistics + + +def test_initialization(monkeypatch): + # Mock the current time + mock_time = 1234567890000 + monkeypatch.setattr( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: mock_time + ) + + stats = Statistics() + assert stats.total_hits == 0 + assert stats.attacks_detected == 0 + assert stats.attacks_blocked == 0 + assert stats.started_at == mock_time + + +def test_clear(monkeypatch): + # Mock the current time + mock_time = 1234567890000 + monkeypatch.setattr( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: mock_time + ) + + stats = Statistics() + stats.total_hits = 10 + stats.attacks_detected = 5 + stats.attacks_blocked = 3 + stats.clear() + + assert stats.total_hits == 0 + assert stats.attacks_detected == 0 + assert stats.attacks_blocked == 0 + assert stats.started_at == mock_time + + +def test_increment_total_hits(): + stats = Statistics() + stats.increment_total_hits() + assert stats.total_hits == 1 + + +def test_on_detected_attack(): + stats = Statistics() + stats.on_detected_attack(blocked=True) + assert stats.attacks_detected == 1 + assert stats.attacks_blocked == 1 + + stats.on_detected_attack(blocked=False) + assert stats.attacks_detected == 2 + assert stats.attacks_blocked == 1 + + +def test_get_record(monkeypatch): + # Mock the current time + mock_time = 1234567890000 + monkeypatch.setattr( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", lambda: mock_time + ) + + stats = Statistics() + stats.total_hits = 10 + stats.attacks_detected = 5 + stats.attacks_blocked = 3 + + record = stats.get_record() + assert record["startedAt"] == stats.started_at + assert record["endedAt"] == mock_time + assert record["requests"]["total"] == 10 + assert record["requests"]["aborted"] == 0 + assert record["requests"]["attacksDetected"]["total"] == 5 + assert record["requests"]["attacksDetected"]["blocked"] == 3 + + +def test_import_from_record(): + stats = Statistics() + record = { + "requests": { + "total": 10, + "attacksDetected": { + "total": 5, + "blocked": 3, + }, + } + } + stats.import_from_record(record) + assert stats.total_hits == 10 + assert stats.attacks_detected == 5 + assert stats.attacks_blocked == 3 + + +def test_empty(): + stats = Statistics() + assert stats.empty() == True + + stats.total_hits = 1 + assert stats.empty() == False + + stats.total_hits = 0 + stats.attacks_detected = 1 + assert stats.empty() == False + + +def test_multiple_imports(): + stats = Statistics() + record1 = { + "requests": { + "total": 10, + "attacksDetected": { + "total": 5, + "blocked": 3, + }, + } + } + record2 = { + "requests": { + "total": 20, + "attacksDetected": { + "total": 10, + "blocked": 7, + }, + } + } + stats.import_from_record(record1) + stats.import_from_record(record2) + assert stats.total_hits == 30 + assert stats.attacks_detected == 15 + assert stats.attacks_blocked == 10 + + +def test_import_empty_record(): + stats = Statistics() + record = {"requests": {}} + stats.import_from_record(record) + assert stats.total_hits == 0 + assert stats.attacks_detected == 0 + assert stats.attacks_blocked == 0 + + +def test_import_partial_record(): + stats = Statistics() + record = {"requests": {"total": 10}} + stats.import_from_record(record) + assert stats.total_hits == 10 + assert stats.attacks_detected == 0 + assert stats.attacks_blocked == 0 + + +def test_increment_and_detect(): + stats = Statistics() + stats.increment_total_hits() + stats.on_detected_attack(blocked=True) + assert stats.total_hits == 1 + assert stats.attacks_detected == 1 + assert stats.attacks_blocked == 1 + + +def test_multiple_increments_and_detects(): + stats = Statistics() + for _ in range(10): + stats.increment_total_hits() + for _ in range(5): + stats.on_detected_attack(blocked=True) + for _ in range(5): + stats.on_detected_attack(blocked=False) + assert stats.total_hits == 10 + assert stats.attacks_detected == 10 + assert stats.attacks_blocked == 5 diff --git a/aikido_zen/thread/thread_cache.py b/aikido_zen/thread/thread_cache.py index 6e6940144..12a83d211 100644 --- a/aikido_zen/thread/thread_cache.py +++ b/aikido_zen/thread/thread_cache.py @@ -6,6 +6,7 @@ from aikido_zen.context import get_current_context from aikido_zen.helpers.logging import logger from aikido_zen.storage.hostnames import Hostnames +from aikido_zen.storage.statistics import Statistics from aikido_zen.storage.users import Users from aikido_zen.thread import process_worker_loader @@ -18,6 +19,7 @@ class ThreadCache: def __init__(self): self.hostnames = Hostnames(200) self.users = Users(1000) + self.stats = Statistics() self.reset() # Initialize values def is_bypassed_ip(self, ip): @@ -41,10 +43,10 @@ def reset(self): last_updated_at=-1, received_any_stats=False, ) - self.reqs = 0 self.middleware_installed = False self.hostnames.clear() self.users.clear() + self.stats.clear() def renew(self): if not comms.get_comms(): @@ -55,10 +57,10 @@ def renew(self): action="SYNC_DATA", obj={ "current_routes": self.routes.get_routes_with_hits(), - "reqs": self.reqs, "middleware_installed": self.middleware_installed, "hostnames": self.hostnames.as_array(), "users": self.users.as_array(), + "stats": self.stats.get_record(), }, receive=True, ) @@ -76,10 +78,6 @@ def renew(self): for route in self.routes.routes.values(): route["hits_delta_since_sync"] = 0 - def increment_stats(self): - """Increments the requests""" - self.reqs += 1 - # For these 2 functions and the data they process, we rely on Python's GIL # See here: https://wiki.python.org/moin/GlobalInterpreterLock diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py index 3eca67ce3..08846ad94 100644 --- a/aikido_zen/thread/thread_cache_test.py +++ b/aikido_zen/thread/thread_cache_test.py @@ -11,7 +11,10 @@ @pytest.fixture def thread_cache(): """Fixture to create a ThreadCache instance.""" - return ThreadCache() + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", return_value=-1 + ): + return ThreadCache() class Context2(Context): @@ -36,7 +39,11 @@ def test_initialization(thread_cache: ThreadCache): assert isinstance(thread_cache.config.bypassed_ips, IPList) assert thread_cache.get_endpoints() == [] assert thread_cache.config.blocked_uids == set() - assert thread_cache.reqs == 0 + assert thread_cache.stats.get_record()["requests"] == { + "total": 0, + "aborted": 0, + "attacksDetected": {"total": 0, "blocked": 0}, + } def test_is_bypassed_ip(thread_cache: ThreadCache): @@ -60,20 +67,27 @@ def test_reset(thread_cache: ThreadCache): """Test that reset empties the cache.""" thread_cache.config.bypassed_ips.add("192.168.1.1") thread_cache.config.blocked_uids.add("user123") + thread_cache.stats.increment_total_hits() + thread_cache.stats.on_detected_attack(blocked=True) + thread_cache.reset() assert isinstance(thread_cache.config.bypassed_ips, IPList) assert thread_cache.config.blocked_uids == set() - assert thread_cache.reqs == 0 + assert thread_cache.stats.get_record()["requests"] == { + "total": 0, + "aborted": 0, + "attacksDetected": {"total": 0, "blocked": 0}, + } -def test_increment_stats(thread_cache): +def test_increment_total_hits(thread_cache): """Test that incrementing stats works correctly.""" - assert thread_cache.reqs == 0 - thread_cache.increment_stats() - assert thread_cache.reqs == 1 - thread_cache.increment_stats() - assert thread_cache.reqs == 2 + assert thread_cache.stats.get_record()["requests"]["total"] == 0 + thread_cache.stats.increment_total_hits() + assert thread_cache.stats.get_record()["requests"]["total"] == 1 + thread_cache.stats.increment_total_hits() + assert thread_cache.stats.get_record()["requests"]["total"] == 2 def test_renew_with_no_comms(thread_cache: ThreadCache): @@ -83,7 +97,11 @@ def test_renew_with_no_comms(thread_cache: ThreadCache): assert isinstance(thread_cache.config.bypassed_ips, IPList) assert thread_cache.get_endpoints() == [] assert thread_cache.config.blocked_uids == set() - assert thread_cache.reqs == 0 + assert thread_cache.stats.get_record()["requests"] == { + "total": 0, + "aborted": 0, + "attacksDetected": {"total": 0, "blocked": 0}, + } @patch("aikido_zen.background_process.comms.get_comms") @@ -117,7 +135,7 @@ def test_increment_stats_thread_safety(thread_cache): def increment_in_thread(): for _ in range(100): - thread_cache.increment_stats() + thread_cache.stats.increment_total_hits() threads = [Thread(target=increment_in_thread) for _ in range(10)] for thread in threads: @@ -125,7 +143,9 @@ def increment_in_thread(): for thread in threads: thread.join() - assert thread_cache.reqs == 1000 # 10 threads incrementing 100 times + assert ( + thread_cache.stats.get_record()["requests"]["total"] == 1000 + ) # 10 threads incrementing 100 times @patch("aikido_zen.background_process.comms.get_comms") @@ -214,12 +234,19 @@ def test_renew_called_with_correct_args(mock_get_comms, thread_cache: ThreadCach mock_get_comms.return_value = mock_comms # Setup initial state - thread_cache.increment_stats() + thread_cache.stats.increment_total_hits() + thread_cache.stats.increment_total_hits() + thread_cache.stats.on_detected_attack(blocked=True) + thread_cache.stats.on_detected_attack(blocked=False) + thread_cache.stats.on_detected_attack(blocked=False) thread_cache.routes.initialize_route({"method": "GET", "route": "/test"}) thread_cache.routes.increment_route({"method": "GET", "route": "/test"}) # Call renew - thread_cache.renew() + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", return_value=-1 + ): + thread_cache.renew() # Assert that send_data_to_bg_process was called with the correct arguments mock_comms.send_data_to_bg_process.assert_called_once_with( @@ -234,7 +261,15 @@ def test_renew_called_with_correct_args(mock_get_comms, thread_cache: ThreadCach "apispec": {}, } }, - "reqs": 1, + "stats": { + "startedAt": -1, + "endedAt": -1, + "requests": { + "total": 2, + "aborted": 0, + "attacksDetected": {"blocked": 1, "total": 3}, + }, + }, "middleware_installed": False, "hostnames": [], "users": [], @@ -251,7 +286,7 @@ def test_sync_data_for_users(mock_get_comms, thread_cache: ThreadCache): Context2().set_as_current_context() # Setup initial state - thread_cache.increment_stats() + thread_cache.stats.increment_total_hits() with patch("aikido_zen.thread.thread_cache.get_cache", return_value=thread_cache): with patch( "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", return_value=1 @@ -260,14 +295,25 @@ def test_sync_data_for_users(mock_get_comms, thread_cache: ThreadCache): set_user({"id": "567", "name": "test"}) # Call renew - thread_cache.renew() + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", return_value=-1 + ): + thread_cache.renew() # Assert that send_data_to_bg_process was called with the correct arguments mock_comms.send_data_to_bg_process.assert_called_once_with( action="SYNC_DATA", obj={ "current_routes": {}, - "reqs": 1, + "stats": { + "startedAt": -1, + "endedAt": -1, + "requests": { + "total": 1, + "aborted": 0, + "attacksDetected": {"total": 0, "blocked": 0}, + }, + }, "middleware_installed": False, "hostnames": [], "users": [ @@ -298,14 +344,25 @@ def test_renew_called_with_empty_routes(mock_get_comms, thread_cache: ThreadCach mock_get_comms.return_value = mock_comms # Call renew without initializing any routes - thread_cache.renew() + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", return_value=-1 + ): + thread_cache.renew() # Assert that send_data_to_bg_process was called with the correct arguments mock_comms.send_data_to_bg_process.assert_called_once_with( action="SYNC_DATA", obj={ "current_routes": {}, - "reqs": 0, + "stats": { + "startedAt": -1, + "endedAt": -1, + "requests": { + "total": 0, + "aborted": 0, + "attacksDetected": {"total": 0, "blocked": 0}, + }, + }, "middleware_installed": False, "hostnames": [], "users": [], @@ -324,14 +381,25 @@ def test_renew_called_with_no_requests(mock_get_comms, thread_cache: ThreadCache thread_cache.routes.initialize_route({"method": "GET", "route": "/test"}) # Call renew - thread_cache.renew() + with patch( + "aikido_zen.helpers.get_current_unixtime_ms.get_unixtime_ms", return_value=-1 + ): + thread_cache.renew() # Assert that send_data_to_bg_process was called with the correct arguments mock_comms.send_data_to_bg_process.assert_called_once_with( action="SYNC_DATA", obj={ "current_routes": {}, - "reqs": 0, + "stats": { + "startedAt": -1, + "endedAt": -1, + "requests": { + "total": 0, + "aborted": 0, + "attacksDetected": {"total": 0, "blocked": 0}, + }, + }, "middleware_installed": False, "hostnames": [], "users": [], diff --git a/aikido_zen/vulnerabilities/__init__.py b/aikido_zen/vulnerabilities/__init__.py index 601991b98..006a5724b 100644 --- a/aikido_zen/vulnerabilities/__init__.py +++ b/aikido_zen/vulnerabilities/__init__.py @@ -97,8 +97,12 @@ def run_vulnerability_scan(kind, op, args): if injection_results: logger.debug("Injection results : %s", serialize_to_json(injection_results)) + blocked = is_blocking_enabled() + thread_cache.stats.on_detected_attack(blocked) + stack = get_clean_stacktrace() + if comms: comms.send_data_to_bg_process( "ATTACK", (injection_results, context, blocked, stack) diff --git a/aikido_zen/vulnerabilities/init_test.py b/aikido_zen/vulnerabilities/init_test.py index 954213787..d502a9a18 100644 --- a/aikido_zen/vulnerabilities/init_test.py +++ b/aikido_zen/vulnerabilities/init_test.py @@ -93,12 +93,30 @@ def test_lifecycle_cache_bypassed_ip(caplog, get_context): def test_sql_injection(caplog, get_context, monkeypatch): get_context.set_as_current_context() monkeypatch.setenv("AIKIDO_BLOCK", "1") + + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 with pytest.raises(AikidoSQLInjection): run_vulnerability_scan( kind="sql_injection", op="test_op", args=("INSERT * INTO VALUES ('doggoss2', TRUE);", "mysql"), ) + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 1 + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["blocked"] == 1 + + +def test_sql_injection_but_blocking_off(caplog, get_context, monkeypatch): + get_context.set_as_current_context() + monkeypatch.setenv("AIKIDO_BLOCK", "0") + + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 + run_vulnerability_scan( + kind="sql_injection", + op="test_op", + args=("INSERT * INTO VALUES ('doggoss2', TRUE);", "mysql"), + ) + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 1 + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["blocked"] == 0 def test_sql_injection_with_route_params(caplog, get_context, monkeypatch): @@ -162,7 +180,9 @@ def test_ssrf_vulnerability_scan_no_port(get_context): hostname = "example.com" port = 0 # Port is zero, should not add to hostnames + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port)) + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 assert get_cache().hostnames.as_array() == [] @@ -176,7 +196,9 @@ def test_ssrf_vulnerability_scan_bypassed_ip(get_context): hostname = "example.com" port = 80 + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port)) + assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 # Verify that hostnames.add was not called due to bypassed IP assert get_cache().hostnames.as_array() == []