diff --git a/aikido_zen/background_process/commands/sync_data_test.py b/aikido_zen/background_process/commands/sync_data_test.py index b12451a02..4827947f3 100644 --- a/aikido_zen/background_process/commands/sync_data_test.py +++ b/aikido_zen/background_process/commands/sync_data_test.py @@ -54,6 +54,7 @@ def test_process_sync_data_initialization(setup_connection_manager): "endedAt": 1, "requests": { "total": 10, + "rateLimited": 0, "aborted": 0, "attacksDetected": { "total": 5, @@ -94,6 +95,7 @@ def test_process_sync_data_initialization(setup_connection_manager): "aborted": 0, "attacksDetected": {"blocked": 0, "total": 5}, "total": 10, + "rateLimited": 0, } # Check that the return value is correct @@ -135,6 +137,7 @@ def test_process_sync_data_with_last_updated_at_below_zero(setup_connection_mana "endedAt": 1, "requests": { "total": 10, + "rateLimited": 0, "aborted": 0, "attacksDetected": { "total": 5, @@ -167,6 +170,7 @@ def test_process_sync_data_with_last_updated_at_below_zero(setup_connection_mana "aborted": 0, "attacksDetected": {"blocked": 0, "total": 5}, "total": 10, + "rateLimited": 0, } assert connection_manager.middleware_installed == True assert len(connection_manager.hostnames.as_array()) == 0 @@ -199,6 +203,7 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager "endedAt": 1, "requests": { "total": 5, + "rateLimited": 0, "aborted": 0, "attacksDetected": { "total": 5, @@ -227,6 +232,7 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager "endedAt": 1, "requests": { "total": 15, + "rateLimited": 0, "aborted": 0, "attacksDetected": { "total": 5, @@ -251,6 +257,7 @@ def test_process_sync_data_existing_route_and_hostnames(setup_connection_manager "aborted": 0, "attacksDetected": {"blocked": 0, "total": 10}, "total": 20, + "rateLimited": 0, } assert connection_manager.middleware_installed == False assert connection_manager.hostnames.as_array() == [ diff --git a/aikido_zen/middleware/init_test.py b/aikido_zen/middleware/init_test.py index 8ee6f2a25..22716deb6 100644 --- a/aikido_zen/middleware/init_test.py +++ b/aikido_zen/middleware/init_test.py @@ -61,6 +61,7 @@ def test_with_context_with_cache(): } assert get_current_context().executed_middleware == True assert thread_cache.middleware_installed == True + assert thread_cache.stats.rate_limited_hits == 0 thread_cache.config.blocked_uids = [] assert should_block_request() == {"block": False} @@ -69,6 +70,7 @@ def test_with_context_with_cache(): assert should_block_request() == {"block": False} assert get_current_context().executed_middleware == True assert thread_cache.middleware_installed == True + assert thread_cache.stats.rate_limited_hits == 0 def test_cache_comms_with_endpoints(): @@ -158,9 +160,11 @@ def test_cache_comms_with_endpoints(): "success": True, "data": {"block": True, "trigger": "my_trigger"}, } + assert thread_cache.stats.rate_limited_hits == 0 assert should_block_request() == { "block": True, "ip": "::1", "type": "ratelimited", "trigger": "my_trigger", } + assert thread_cache.stats.rate_limited_hits == 1 diff --git a/aikido_zen/middleware/should_block_request.py b/aikido_zen/middleware/should_block_request.py index 20e4d1435..360d5fde9 100644 --- a/aikido_zen/middleware/should_block_request.py +++ b/aikido_zen/middleware/should_block_request.py @@ -51,6 +51,7 @@ def should_block_request(): receive=True, ) if ratelimit_res["success"] and ratelimit_res["data"]["block"]: + cache.stats.on_rate_limit() return { "block": True, "type": "ratelimited", diff --git a/aikido_zen/storage/statistics/__init__.py b/aikido_zen/storage/statistics/__init__.py index bedf42ca6..47e7dec3e 100644 --- a/aikido_zen/storage/statistics/__init__.py +++ b/aikido_zen/storage/statistics/__init__.py @@ -4,14 +4,15 @@ class Statistics: """ - Keeps track of total and aborted requests - and total and blocked attacks + Stores: hits, counts of attacks (split up in detected/blocked), count of rate-limited requests, + statistics for operations (i.e. how many times did we see a query being executed) """ def __init__(self): self.total_hits = 0 self.attacks_detected = 0 self.attacks_blocked = 0 + self.rate_limited_hits = 0 self.started_at = t.get_unixtime_ms() self.operations = Operations() @@ -19,6 +20,7 @@ def clear(self): self.total_hits = 0 self.attacks_detected = 0 self.attacks_blocked = 0 + self.rate_limited_hits = 0 self.started_at = t.get_unixtime_ms() self.operations.clear() @@ -31,6 +33,9 @@ def on_detected_attack(self, blocked, operation): self.attacks_blocked += 1 self.operations.on_detected_attack(blocked, operation) + def on_rate_limit(self): + self.rate_limited_hits += 1 + def get_record(self): current_time = t.get_unixtime_ms() return { @@ -38,6 +43,7 @@ def get_record(self): "endedAt": current_time, "requests": { "total": self.total_hits, + "rateLimited": self.rate_limited_hits, "aborted": 0, # statistic currently not in use "attacksDetected": { "total": self.attacks_detected, @@ -50,6 +56,7 @@ def get_record(self): def import_from_record(self, record): attacks_detected = record.get("requests", {}).get("attacksDetected", {}) self.total_hits += record.get("requests", {}).get("total", 0) + self.rate_limited_hits += record.get("requests", {}).get("rateLimited", 0) self.attacks_detected += attacks_detected.get("total", 0) self.attacks_blocked += attacks_detected.get("blocked", 0) self.operations.update(record.get("operations", {})) diff --git a/aikido_zen/storage/statistics/init_test.py b/aikido_zen/storage/statistics/init_test.py index 5ca481d06..7d54ec1c3 100644 --- a/aikido_zen/storage/statistics/init_test.py +++ b/aikido_zen/storage/statistics/init_test.py @@ -68,6 +68,8 @@ def test_get_record(monkeypatch): stats = Statistics() stats.total_hits = 10 + stats.on_rate_limit() + stats.on_rate_limit() stats.operations.register_call("test.test", "nosql_op") stats.on_detected_attack(blocked=True, operation="test.test") stats.attacks_detected = 5 @@ -77,6 +79,7 @@ def test_get_record(monkeypatch): assert record["startedAt"] == stats.started_at assert record["endedAt"] == mock_time assert record["requests"]["total"] == 10 + assert record["requests"]["rateLimited"] == 2 assert record["requests"]["aborted"] == 0 assert record["requests"]["attacksDetected"]["total"] == 5 assert record["requests"]["attacksDetected"]["blocked"] == 3 @@ -97,6 +100,7 @@ def test_import_from_record(): record = { "requests": { "total": 10, + "rateLimited": 5, "attacksDetected": { "total": 5, "blocked": 3, @@ -117,6 +121,7 @@ def test_import_from_record(): } stats.import_from_record(record) assert stats.total_hits == 10 + assert stats.rate_limited_hits == 5 assert stats.attacks_detected == 5 assert stats.attacks_blocked == 3 assert stats.operations == { @@ -152,6 +157,7 @@ def test_multiple_imports(stats): record1 = { "requests": { "total": 10, + "rateLimited": 20, "attacksDetected": { "total": 5, "blocked": 3, @@ -168,6 +174,7 @@ def test_multiple_imports(stats): record2 = { "requests": { "total": 20, + "rateLimited": 5, "attacksDetected": { "total": 10, "blocked": 7, @@ -184,6 +191,7 @@ def test_multiple_imports(stats): stats.import_from_record(record1) stats.import_from_record(record2) assert stats.total_hits == 30 + assert stats.rate_limited_hits == 25 assert stats.attacks_detected == 15 assert stats.attacks_blocked == 10 assert stats.operations == { @@ -204,6 +212,7 @@ def test_import_empty_record(stats): record = {"requests": {}} stats.import_from_record(record) assert stats.total_hits == 0 + assert stats.rate_limited_hits == 0 assert stats.attacks_detected == 0 assert stats.attacks_blocked == 0 assert stats.operations == {} @@ -213,6 +222,7 @@ def test_import_partial_record(stats): record = {"requests": {"total": 10}} stats.import_from_record(record) assert stats.total_hits == 10 + assert stats.rate_limited_hits == 0 assert stats.attacks_detected == 0 assert stats.attacks_blocked == 0 assert stats.operations == {} @@ -242,3 +252,40 @@ def test_multiple_increments_and_detects(stats): "kind": "sql_op", "total": 1, } + + stats.on_rate_limit() + assert stats.rate_limited_hits == 1 + + stats.on_rate_limit() + assert stats.rate_limited_hits == 2 + + +def test_multiple_rate_limits(stats): + """Test multiple rate limit calls""" + for _ in range(5): + stats.on_rate_limit() + assert stats.rate_limited_hits == 5 + + +def test_rate_limit_in_get_record(): + """Test that rate_limited_hits is included in get_record output""" + stats = Statistics() + stats.total_hits = 10 + stats.on_rate_limit() + stats.on_rate_limit() + stats.on_rate_limit() + + record = stats.get_record() + assert record["requests"]["rateLimited"] == 3 + assert record["requests"]["total"] == 10 + + +def test_rate_limit_clear(): + """Test that clear() resets rate_limited_hits""" + stats = Statistics() + stats.on_rate_limit() + stats.on_rate_limit() + assert stats.rate_limited_hits == 2 + + stats.clear() + assert stats.rate_limited_hits == 0 diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py index c64e6fad2..861124b66 100644 --- a/aikido_zen/thread/thread_cache_test.py +++ b/aikido_zen/thread/thread_cache_test.py @@ -42,6 +42,7 @@ def test_initialization(thread_cache: ThreadCache): assert thread_cache.config.blocked_uids == set() assert thread_cache.stats.get_record()["requests"] == { "total": 0, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, } @@ -77,6 +78,7 @@ def test_reset(thread_cache: ThreadCache): assert thread_cache.config.blocked_uids == set() assert thread_cache.stats.get_record()["requests"] == { "total": 0, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, } @@ -100,6 +102,7 @@ def test_renew_with_no_comms(thread_cache: ThreadCache): assert thread_cache.config.blocked_uids == set() assert thread_cache.stats.get_record()["requests"] == { "total": 0, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, } @@ -285,6 +288,7 @@ def test_renew_called_with_correct_args(mock_get_comms, thread_cache: ThreadCach "endedAt": -1, "requests": { "total": 2, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"blocked": 1, "total": 3}, }, @@ -369,6 +373,7 @@ def test_sync_data_for_users(mock_get_comms, thread_cache: ThreadCache): "endedAt": -1, "requests": { "total": 1, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, }, @@ -421,6 +426,7 @@ def test_renew_called_with_empty_routes(mock_get_comms, thread_cache: ThreadCach "endedAt": -1, "requests": { "total": 0, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, }, @@ -461,6 +467,7 @@ def test_renew_called_with_no_requests(mock_get_comms, thread_cache: ThreadCache "endedAt": -1, "requests": { "total": 0, + "rateLimited": 0, "aborted": 0, "attacksDetected": {"total": 0, "blocked": 0}, }, diff --git a/end2end/django_mysql_test.py b/end2end/django_mysql_test.py index db6060195..b293d1101 100644 --- a/end2end/django_mysql_test.py +++ b/end2end/django_mysql_test.py @@ -88,6 +88,6 @@ def test_initial_heartbeat(): "method": "POST", "path": "/app/create" }], - {"aborted":0,"attacksDetected":{"blocked":2,"total":2},"total":3}, + {"aborted":0,"attacksDetected":{"blocked":2,"total":2},"total":3, 'rateLimited': 0}, {'asgiref', 'regex', 'mysqlclient', 'sqlparse', 'aikido_zen', 'django'} )