diff --git a/aikido_zen/ratelimiting/__init__.py b/aikido_zen/ratelimiting/__init__.py index 0264137a0..d3f7e1b56 100644 --- a/aikido_zen/ratelimiting/__init__.py +++ b/aikido_zen/ratelimiting/__init__.py @@ -17,10 +17,7 @@ def should_ratelimit_request(route_metadata, remote_address, user, connection_ma max_requests = int(endpoint["rateLimiting"]["maxRequests"]) windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"]) - is_bypassed_ip = connection_manager.conf.is_bypassed_ip(remote_address) - if is_bypassed_ip: - return {"block": False} if user: uid = user["id"] method = endpoint.get("method") diff --git a/aikido_zen/ratelimiting/init_test.py b/aikido_zen/ratelimiting/init_test.py index 508ec338f..a7d8db317 100644 --- a/aikido_zen/ratelimiting/init_test.py +++ b/aikido_zen/ratelimiting/init_test.py @@ -61,39 +61,6 @@ def test_rate_limits_by_ip(): } -def test_rate_limiting_ip_allowed(): - cm = create_connection_manager( - [ - { - "method": "POST", - "route": "/login", - "forceProtectionOff": False, - "rateLimiting": { - "enabled": True, - "maxRequests": 3, - "windowSizeInMS": 1000, - }, - }, - ], - ["1.2.3.4"], - ) - - # Test requests from allowed IP - route_metadata = create_route_metadata() - assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == { - "block": False - } - assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == { - "block": False - } - assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == { - "block": False - } - assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == { - "block": False - } - - def test_rate_limiting_by_user(user): cm = create_connection_manager( [ @@ -304,47 +271,3 @@ def test_rate_limiting_same_ip_different_users(): assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123456"}, cm) == { "block": False } - - -def test_does_not_ratelimit_bypassed_ip_with_user(): - pass # Really? - - -def test_works_with_setuser_after_first_ratelimit(): - pass - - -import pytest - - -def test_rate_limiting_bypassed_ip_with_user(): - cm = create_connection_manager( - [ - { - "method": "POST", - "route": "/login", - "forceProtectionOff": False, - "rateLimiting": { - "enabled": True, - "maxRequests": 3, - "windowSizeInMS": 1000, - }, - }, - ], - ["1.2.3.4"], - ) - - # All requests from the bypassed IP should not be blocked - metadata = create_route_metadata(route="/login", method="POST") - assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == { - "block": False - } - assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == { - "block": False - } - assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == { - "block": False - } - assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == { - "block": False - } diff --git a/aikido_zen/sources/django/run_init_stage.py b/aikido_zen/sources/django/run_init_stage.py index 4b865c1ae..e91bb560a 100644 --- a/aikido_zen/sources/django/run_init_stage.py +++ b/aikido_zen/sources/django/run_init_stage.py @@ -3,11 +3,11 @@ import json from aikido_zen.context import Context from aikido_zen.helpers.logging import logger -from ..functions.request_handler import request_handler +from ..functions.on_init_handler import on_init_handler def run_init_stage(request): - """Parse request and body, run "init" stage with request_handler""" + """Parse request and body, run the on_init_handler""" body = None try: # try-catch loading of form parameters, this is to fix issue with DATA_UPLOAD_MAX_NUMBER_FIELDS : @@ -38,17 +38,13 @@ def run_init_stage(request): # In a separate try-catch we set the context : try: context = None - if ( - hasattr(request, "scope") and request.scope is not None - ): # This request is an ASGI request + if hasattr(request, "scope") and request.scope is not None: + # This request is an ASGI request context = Context(req=request.scope, body=body, source="django_async") - elif hasattr(request, "META") and request.META is not None: # WSGI request + elif hasattr(request, "META") and request.META is not None: + # This request is a WSGI request context = Context(req=request.META, body=body, source="django") - else: - return - context.set_as_current_context() - # Init stage needs to be run with context already set : - request_handler(stage="init") + on_init_handler(context) except Exception as e: logger.debug("Exception occurred in run_init_stage function (Django): %s", e) diff --git a/aikido_zen/sources/flask.py b/aikido_zen/sources/flask.py index 6fc862749..f98ee8090 100644 --- a/aikido_zen/sources/flask.py +++ b/aikido_zen/sources/flask.py @@ -9,6 +9,7 @@ from aikido_zen.background_process.packages import is_package_compatible, ANY_VERSION from aikido_zen.context import get_current_context import aikido_zen.sources.functions.request_handler as funcs +from aikido_zen.sources.functions.on_init_handler import on_init_handler def aik_full_dispatch_request(*args, former_full_dispatch_request=None, **kwargs): @@ -84,9 +85,8 @@ def aikido___call__(flask_app, environ, start_response): # We don't want to install werkzeug : # pylint: disable=import-outside-toplevel try: - context1 = Context(req=environ, source="flask") - context1.set_as_current_context() - funcs.request_handler(stage="init") + context = Context(req=environ, source="flask") + on_init_handler(context) except Exception as e: logger.debug("Exception on aikido __call__ function : %s", e) res = flask_app.wsgi_app(environ, start_response) diff --git a/aikido_zen/sources/flask_test.py b/aikido_zen/sources/flask_test.py index 6dd9e71db..41ba222e2 100644 --- a/aikido_zen/sources/flask_test.py +++ b/aikido_zen/sources/flask_test.py @@ -111,11 +111,10 @@ def hello(user, age): "CONTENT_TYPE": "application/json", } calls = mock_request_handler.call_args_list - assert len(calls) == 3 - assert calls[0][1]["stage"] == "init" - assert calls[1][1]["stage"] == "pre_response" - assert calls[2][1]["stage"] == "post_response" - assert calls[2][1]["status_code"] == 200 + assert len(calls) == 2 + assert calls[0][1]["stage"] == "pre_response" + assert calls[1][1]["stage"] == "post_response" + assert calls[1][1]["status_code"] == 200 assert get_current_context().route_params["user"] == "JohnDoe" assert get_current_context().route_params["age"] == "30" @@ -146,11 +145,10 @@ def test_flask_all_3_func_with_malformed_cookie(): assert get_current_context().cookies == {"\u0000" * 10: ""} calls = mock_request_handler.call_args_list - assert len(calls) == 3 - assert calls[0][1]["stage"] == "init" - assert calls[1][1]["stage"] == "pre_response" - assert calls[2][1]["stage"] == "post_response" - assert calls[2][1]["status_code"] == 404 + assert len(calls) == 2 + assert calls[0][1]["stage"] == "pre_response" + assert calls[1][1]["stage"] == "post_response" + assert calls[1][1]["status_code"] == 404 def test_flask_all_3_func_with_invalid_body(): @@ -184,11 +182,10 @@ def test_flask_all_3_func_with_invalid_body(): "CONTENT_TYPE": "application/json", } calls = mock_request_handler.call_args_list - assert len(calls) == 3 - assert calls[0][1]["stage"] == "init" - assert calls[1][1]["stage"] == "pre_response" - assert calls[2][1]["stage"] == "post_response" - assert calls[2][1]["status_code"] == 404 + assert len(calls) == 2 + assert calls[0][1]["stage"] == "pre_response" + assert calls[1][1]["stage"] == "post_response" + assert calls[1][1]["status_code"] == 404 def test_flask_all_3_func(): @@ -218,11 +215,10 @@ def test_flask_all_3_func(): "CONTENT_TYPE": "application/x-www-form-urlencoded", } calls = mock_request_handler.call_args_list - assert len(calls) == 3 - assert calls[0][1]["stage"] == "init" - assert calls[1][1]["stage"] == "pre_response" - assert calls[2][1]["stage"] == "post_response" - assert calls[2][1]["status_code"] == 404 + assert len(calls) == 2 + assert calls[0][1]["stage"] == "pre_response" + assert calls[1][1]["stage"] == "post_response" + assert calls[1][1]["status_code"] == 404 def test_startup_flask(): diff --git a/aikido_zen/sources/functions/on_init_handler.py b/aikido_zen/sources/functions/on_init_handler.py new file mode 100644 index 000000000..b883af277 --- /dev/null +++ b/aikido_zen/sources/functions/on_init_handler.py @@ -0,0 +1,23 @@ +from aikido_zen.context import Context +from aikido_zen.thread.thread_cache import get_cache + + +def on_init_handler(context: Context): + """ + On-Init Handler should be called after a context has been created, the function will : + - Store context + - Renew thread cache if necessary and store the hits + - Check if IP is bypassed + """ + if context is None: + return + + cache = get_cache() + if cache is not None and cache.is_bypassed_ip(context.remote_address): + return # Do not store the context of bypassed IPs, skip request processing. + context.set_as_current_context() + + if cache is not None: + # Only check the TTL at the start of a request. + cache.renew_if_ttl_expired() + cache.increment_stats() diff --git a/aikido_zen/sources/functions/on_init_handler_test.py b/aikido_zen/sources/functions/on_init_handler_test.py new file mode 100644 index 000000000..81090dbd2 --- /dev/null +++ b/aikido_zen/sources/functions/on_init_handler_test.py @@ -0,0 +1,51 @@ +import pytest +from unittest.mock import MagicMock, patch +from aikido_zen.context import Context, get_current_context, current_context +from .on_init_handler import on_init_handler +from ...thread.thread_cache import ThreadCache, threadlocal_storage + + +@pytest.fixture +def mock_context(): + """Fixture to create a mock context.""" + context = MagicMock(spec=Context) + context.remote_address = "192.168.1.1" # Example IP + return context + + +@pytest.fixture(autouse=True) +def run_around_tests(): + current_context.set(None) + threadlocal_storage.cache = None + yield + current_context.set(None) + threadlocal_storage.cache = None + + +def test_on_init_handler_with_none_context(): + """Test that the function returns early if context is None.""" + result = on_init_handler(None) + assert result is None # No return value, just ensure it doesn't raise an error + + +def test_on_init_handler_with_bypassed_ip(mock_context): + """Test that the function returns early if the IP is bypassed.""" + cache = ThreadCache() + cache.config.set_bypassed_ips(["192.168.1.1"]) + on_init_handler(mock_context) + mock_context.set_as_current_context.assert_not_called() + + +def test_on_init_handler_with_valid_context(mock_context): + """Test that the function processes the context correctly when IP is not bypassed.""" + cache = ThreadCache() + on_init_handler(mock_context) + mock_context.set_as_current_context.assert_called_once() # Should set context + assert cache.reqs == 1 + on_init_handler(mock_context) + assert cache.reqs == 2 + + +def test_on_init_handler_with_valid_context_but_empty_thread_cache(mock_context): + on_init_handler(mock_context) + mock_context.set_as_current_context.assert_called_once() # Should set context diff --git a/aikido_zen/sources/functions/request_handler.py b/aikido_zen/sources/functions/request_handler.py index 0ea7cb006..a2b2dc887 100644 --- a/aikido_zen/sources/functions/request_handler.py +++ b/aikido_zen/sources/functions/request_handler.py @@ -14,13 +14,6 @@ def request_handler(stage, status_code=0): """This will check for rate limiting, Allowed IP's, useful routes, etc.""" try: - if stage == "init": - # Initial stage of the request, called after context is stored. - thread_cache = get_cache() - thread_cache.renew_if_ttl_expired() # Only check TTL at the start of a request. - if ctx.get_current_context() and thread_cache: - thread_cache.increment_stats() # Increment request statistics if a context exists. - if stage == "pre_response": return pre_response() if stage == "post_response": diff --git a/aikido_zen/sources/quart.py b/aikido_zen/sources/quart.py index a2057362f..540f09183 100644 --- a/aikido_zen/sources/quart.py +++ b/aikido_zen/sources/quart.py @@ -7,6 +7,7 @@ from aikido_zen.helpers.logging import logger from aikido_zen.context import Context, get_current_context from aikido_zen.background_process.packages import is_package_compatible, ANY_VERSION +from .functions.on_init_handler import on_init_handler from .functions.request_handler import request_handler @@ -17,10 +18,8 @@ async def aikido___call___wrapper(former_call, quart_app, scope, receive, send): try: if scope["type"] != "http": return await former_call(quart_app, scope, receive, send) - context1 = Context(req=scope, source="quart") - context1.set_as_current_context() - - request_handler(stage="init") + context = Context(req=scope, source="quart") + on_init_handler(context) except Exception as e: logger.debug("Exception on aikido __call__ function : %s", e) return await former_call(quart_app, scope, receive, send) diff --git a/aikido_zen/sources/starlette/__init__.py b/aikido_zen/sources/starlette/__init__.py index e22cb1891..e94951511 100644 --- a/aikido_zen/sources/starlette/__init__.py +++ b/aikido_zen/sources/starlette/__init__.py @@ -2,7 +2,7 @@ Init.py file for starlette module --- Starlette wrapping is subdivided in two parts : -- starlette.applications : Wraps __call__ on Starlette class to run "init" stage. +- starlette.applications : Wraps __call__ on Starlette class to run the on_init_handler - starlette.routing : request_response function : Run pre_response code and also runs post_response code after getting response from user function. diff --git a/aikido_zen/sources/starlette/starlette_applications.py b/aikido_zen/sources/starlette/starlette_applications.py index ce5ca8f66..3eeb573e2 100644 --- a/aikido_zen/sources/starlette/starlette_applications.py +++ b/aikido_zen/sources/starlette/starlette_applications.py @@ -1,17 +1,17 @@ -"""Wraps starlette.applications for initial request_handler""" +"""Wraps starlette.applications to create a context object and run the on_init_handler""" import copy import aikido_zen.importhook as importhook from aikido_zen.helpers.logging import logger from aikido_zen.context import Context -from ..functions.request_handler import request_handler +from ..functions.on_init_handler import on_init_handler @importhook.on_import("starlette.applications") def on_starlette_import(starlette): """ Hook 'n wrap on `starlette.applications` - Our goal is to wrap the __call__ function of the Starlette class + Our goal is to wrap the __call__ function of the Starlette class, so we can create a Context object. """ modified_starlette = importhook.copy_module(starlette) former_call = copy.deepcopy(starlette.Starlette.__call__) @@ -28,9 +28,8 @@ async def aik_call_wrapper(former_call, app, scope, receive, send): try: if scope["type"] != "http": return await former_call(app, scope, receive, send) - context1 = Context(req=scope, source="starlette") - context1.set_as_current_context() - request_handler(stage="init") + context = Context(req=scope, source="starlette") + on_init_handler(context) except Exception as e: logger.debug("Exception on aikido __call__ function : %s", e) return await former_call(app, scope, receive, send) diff --git a/aikido_zen/vulnerabilities/__init__.py b/aikido_zen/vulnerabilities/__init__.py index 37de6656e..031537a1f 100644 --- a/aikido_zen/vulnerabilities/__init__.py +++ b/aikido_zen/vulnerabilities/__init__.py @@ -52,9 +52,6 @@ def run_vulnerability_scan(kind, op, args): ): # The client turned protection off for this route, not scanning return - if thread_cache.is_bypassed_ip(context.remote_address): - # This IP is on the bypass list, not scanning - return error_type = AikidoException # Default error error_args = tuple() diff --git a/aikido_zen/vulnerabilities/init_test.py b/aikido_zen/vulnerabilities/init_test.py index 32264c46e..f339bfa5f 100644 --- a/aikido_zen/vulnerabilities/init_test.py +++ b/aikido_zen/vulnerabilities/init_test.py @@ -80,16 +80,6 @@ def test_ssrf(caplog, get_context): run_vulnerability_scan(kind="ssrf", op="test", args=tuple()) -def test_lifecycle_cache_bypassed_ip(caplog, get_context): - get_context.set_as_current_context() - cache = ThreadCache() - cache.config.bypassed_ips = IPList() - cache.config.bypassed_ips.add("198.51.100.23") - assert cache.is_bypassed_ip("198.51.100.23") - run_vulnerability_scan(kind="test", op="test", args=tuple()) - assert len(caplog.text) == 0 - - def test_sql_injection(caplog, get_context, monkeypatch): get_context.set_as_current_context() cache = ThreadCache() diff --git a/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py b/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py index a263a96a1..c5b1c6a4b 100644 --- a/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py +++ b/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py @@ -29,10 +29,6 @@ def inspect_getaddrinfo_result(dns_results, hostname, port): context = get_current_context() if not context: return # Context should be set to check user input. - if get_cache() and get_cache().is_bypassed_ip(context.remote_address): - # We check for bypassed ip's here since it is not checked for us - # in run_vulnerability_scan due to the exception for SSRF (see above code) - return # attack_findings is an object containing source, pathToPayload and payload. attack_findings = find_hostname_in_context(hostname, context, port)