Skip to content

Create a unified on_init_handler, that does not set context if IP is bypassed #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions aikido_zen/ratelimiting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ def should_ratelimit_request(route_metadata, remote_address, user, connection_ma

max_requests = int(endpoint["rateLimiting"]["maxRequests"])
windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"])
is_bypassed_ip = connection_manager.conf.is_bypassed_ip(remote_address)

if is_bypassed_ip:
return {"block": False}
if user:
uid = user["id"]
method = endpoint.get("method")
Expand Down
77 changes: 0 additions & 77 deletions aikido_zen/ratelimiting/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,39 +61,6 @@ def test_rate_limits_by_ip():
}


def test_rate_limiting_ip_allowed():
cm = create_connection_manager(
[
{
"method": "POST",
"route": "/login",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": True,
"maxRequests": 3,
"windowSizeInMS": 1000,
},
},
],
["1.2.3.4"],
)

# Test requests from allowed IP
route_metadata = create_route_metadata()
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
"block": False
}
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
"block": False
}
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
"block": False
}
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
"block": False
}


def test_rate_limiting_by_user(user):
cm = create_connection_manager(
[
Expand Down Expand Up @@ -304,47 +271,3 @@ def test_rate_limiting_same_ip_different_users():
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123456"}, cm) == {
"block": False
}


def test_does_not_ratelimit_bypassed_ip_with_user():
pass # Really?


def test_works_with_setuser_after_first_ratelimit():
pass


import pytest


def test_rate_limiting_bypassed_ip_with_user():
cm = create_connection_manager(
[
{
"method": "POST",
"route": "/login",
"forceProtectionOff": False,
"rateLimiting": {
"enabled": True,
"maxRequests": 3,
"windowSizeInMS": 1000,
},
},
],
["1.2.3.4"],
)

# All requests from the bypassed IP should not be blocked
metadata = create_route_metadata(route="/login", method="POST")
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == {
"block": False
}
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == {
"block": False
}
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == {
"block": False
}
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == {
"block": False
}
18 changes: 7 additions & 11 deletions aikido_zen/sources/django/run_init_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import json
from aikido_zen.context import Context
from aikido_zen.helpers.logging import logger
from ..functions.request_handler import request_handler
from ..functions.on_init_handler import on_init_handler


def run_init_stage(request):
"""Parse request and body, run "init" stage with request_handler"""
"""Parse request and body, run the on_init_handler"""
body = None
try:
# try-catch loading of form parameters, this is to fix issue with DATA_UPLOAD_MAX_NUMBER_FIELDS :
Expand Down Expand Up @@ -38,17 +38,13 @@ def run_init_stage(request):
# In a separate try-catch we set the context :
try:
context = None
if (
hasattr(request, "scope") and request.scope is not None
): # This request is an ASGI request
if hasattr(request, "scope") and request.scope is not None:
# This request is an ASGI request
context = Context(req=request.scope, body=body, source="django_async")
elif hasattr(request, "META") and request.META is not None: # WSGI request
elif hasattr(request, "META") and request.META is not None:
# This request is a WSGI request
context = Context(req=request.META, body=body, source="django")
else:
return
context.set_as_current_context()

# Init stage needs to be run with context already set :
request_handler(stage="init")
on_init_handler(context)
except Exception as e:
logger.debug("Exception occurred in run_init_stage function (Django): %s", e)
6 changes: 3 additions & 3 deletions aikido_zen/sources/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aikido_zen.background_process.packages import is_package_compatible, ANY_VERSION
from aikido_zen.context import get_current_context
import aikido_zen.sources.functions.request_handler as funcs
from aikido_zen.sources.functions.on_init_handler import on_init_handler


def aik_full_dispatch_request(*args, former_full_dispatch_request=None, **kwargs):
Expand Down Expand Up @@ -84,9 +85,8 @@ def aikido___call__(flask_app, environ, start_response):
# We don't want to install werkzeug :
# pylint: disable=import-outside-toplevel
try:
context1 = Context(req=environ, source="flask")
context1.set_as_current_context()
funcs.request_handler(stage="init")
context = Context(req=environ, source="flask")
on_init_handler(context)
except Exception as e:
logger.debug("Exception on aikido __call__ function : %s", e)
res = flask_app.wsgi_app(environ, start_response)
Expand Down
36 changes: 16 additions & 20 deletions aikido_zen/sources/flask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,10 @@ def hello(user, age):
"CONTENT_TYPE": "application/json",
}
calls = mock_request_handler.call_args_list
assert len(calls) == 3
assert calls[0][1]["stage"] == "init"
assert calls[1][1]["stage"] == "pre_response"
assert calls[2][1]["stage"] == "post_response"
assert calls[2][1]["status_code"] == 200
assert len(calls) == 2
assert calls[0][1]["stage"] == "pre_response"
assert calls[1][1]["stage"] == "post_response"
assert calls[1][1]["status_code"] == 200

assert get_current_context().route_params["user"] == "JohnDoe"
assert get_current_context().route_params["age"] == "30"
Expand Down Expand Up @@ -146,11 +145,10 @@ def test_flask_all_3_func_with_malformed_cookie():
assert get_current_context().cookies == {"\u0000" * 10: ""}

calls = mock_request_handler.call_args_list
assert len(calls) == 3
assert calls[0][1]["stage"] == "init"
assert calls[1][1]["stage"] == "pre_response"
assert calls[2][1]["stage"] == "post_response"
assert calls[2][1]["status_code"] == 404
assert len(calls) == 2
assert calls[0][1]["stage"] == "pre_response"
assert calls[1][1]["stage"] == "post_response"
assert calls[1][1]["status_code"] == 404


def test_flask_all_3_func_with_invalid_body():
Expand Down Expand Up @@ -184,11 +182,10 @@ def test_flask_all_3_func_with_invalid_body():
"CONTENT_TYPE": "application/json",
}
calls = mock_request_handler.call_args_list
assert len(calls) == 3
assert calls[0][1]["stage"] == "init"
assert calls[1][1]["stage"] == "pre_response"
assert calls[2][1]["stage"] == "post_response"
assert calls[2][1]["status_code"] == 404
assert len(calls) == 2
assert calls[0][1]["stage"] == "pre_response"
assert calls[1][1]["stage"] == "post_response"
assert calls[1][1]["status_code"] == 404


def test_flask_all_3_func():
Expand Down Expand Up @@ -218,11 +215,10 @@ def test_flask_all_3_func():
"CONTENT_TYPE": "application/x-www-form-urlencoded",
}
calls = mock_request_handler.call_args_list
assert len(calls) == 3
assert calls[0][1]["stage"] == "init"
assert calls[1][1]["stage"] == "pre_response"
assert calls[2][1]["stage"] == "post_response"
assert calls[2][1]["status_code"] == 404
assert len(calls) == 2
assert calls[0][1]["stage"] == "pre_response"
assert calls[1][1]["stage"] == "post_response"
assert calls[1][1]["status_code"] == 404


def test_startup_flask():
Expand Down
23 changes: 23 additions & 0 deletions aikido_zen/sources/functions/on_init_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from aikido_zen.context import Context
from aikido_zen.thread.thread_cache import get_cache


def on_init_handler(context: Context):
"""
On-Init Handler should be called after a context has been created, the function will :
- Store context
- Renew thread cache if necessary and store the hits
- Check if IP is bypassed
"""
if context is None:
return

cache = get_cache()
if cache is not None and cache.is_bypassed_ip(context.remote_address):
return # Do not store the context of bypassed IPs, skip request processing.
context.set_as_current_context()

if cache is not None:
# Only check the TTL at the start of a request.
cache.renew_if_ttl_expired()
cache.increment_stats()
51 changes: 51 additions & 0 deletions aikido_zen/sources/functions/on_init_handler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from unittest.mock import MagicMock, patch
from aikido_zen.context import Context, get_current_context, current_context
from .on_init_handler import on_init_handler
from ...thread.thread_cache import ThreadCache, threadlocal_storage


@pytest.fixture
def mock_context():
"""Fixture to create a mock context."""
context = MagicMock(spec=Context)
context.remote_address = "192.168.1.1" # Example IP
return context


@pytest.fixture(autouse=True)
def run_around_tests():
current_context.set(None)
threadlocal_storage.cache = None
yield
current_context.set(None)
threadlocal_storage.cache = None


def test_on_init_handler_with_none_context():
"""Test that the function returns early if context is None."""
result = on_init_handler(None)
assert result is None # No return value, just ensure it doesn't raise an error


def test_on_init_handler_with_bypassed_ip(mock_context):
"""Test that the function returns early if the IP is bypassed."""
cache = ThreadCache()
cache.config.set_bypassed_ips(["192.168.1.1"])
on_init_handler(mock_context)
mock_context.set_as_current_context.assert_not_called()


def test_on_init_handler_with_valid_context(mock_context):
"""Test that the function processes the context correctly when IP is not bypassed."""
cache = ThreadCache()
on_init_handler(mock_context)
mock_context.set_as_current_context.assert_called_once() # Should set context
assert cache.reqs == 1
on_init_handler(mock_context)
assert cache.reqs == 2


def test_on_init_handler_with_valid_context_but_empty_thread_cache(mock_context):
on_init_handler(mock_context)
mock_context.set_as_current_context.assert_called_once() # Should set context
7 changes: 0 additions & 7 deletions aikido_zen/sources/functions/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@
def request_handler(stage, status_code=0):
"""This will check for rate limiting, Allowed IP's, useful routes, etc."""
try:
if stage == "init":
# Initial stage of the request, called after context is stored.
thread_cache = get_cache()
thread_cache.renew_if_ttl_expired() # Only check TTL at the start of a request.
if ctx.get_current_context() and thread_cache:
thread_cache.increment_stats() # Increment request statistics if a context exists.

if stage == "pre_response":
return pre_response()
if stage == "post_response":
Expand Down
7 changes: 3 additions & 4 deletions aikido_zen/sources/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aikido_zen.helpers.logging import logger
from aikido_zen.context import Context, get_current_context
from aikido_zen.background_process.packages import is_package_compatible, ANY_VERSION
from .functions.on_init_handler import on_init_handler
from .functions.request_handler import request_handler


Expand All @@ -17,10 +18,8 @@
try:
if scope["type"] != "http":
return await former_call(quart_app, scope, receive, send)
context1 = Context(req=scope, source="quart")
context1.set_as_current_context()

request_handler(stage="init")
context = Context(req=scope, source="quart")
on_init_handler(context)

Check warning on line 22 in aikido_zen/sources/quart.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/quart.py#L21-L22

Added lines #L21 - L22 were not covered by tests
except Exception as e:
logger.debug("Exception on aikido __call__ function : %s", e)
return await former_call(quart_app, scope, receive, send)
Expand Down
2 changes: 1 addition & 1 deletion aikido_zen/sources/starlette/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Init.py file for starlette module
---
Starlette wrapping is subdivided in two parts :
- starlette.applications : Wraps __call__ on Starlette class to run "init" stage.
- starlette.applications : Wraps __call__ on Starlette class to run the on_init_handler
- starlette.routing : request_response function : Run pre_response code and
also runs post_response code after getting response from user function.

Expand Down
11 changes: 5 additions & 6 deletions aikido_zen/sources/starlette/starlette_applications.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""Wraps starlette.applications for initial request_handler"""
"""Wraps starlette.applications to create a context object and run the on_init_handler"""

import copy
import aikido_zen.importhook as importhook
from aikido_zen.helpers.logging import logger
from aikido_zen.context import Context
from ..functions.request_handler import request_handler
from ..functions.on_init_handler import on_init_handler

Check warning on line 7 in aikido_zen/sources/starlette/starlette_applications.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/starlette/starlette_applications.py#L7

Added line #L7 was not covered by tests


@importhook.on_import("starlette.applications")
def on_starlette_import(starlette):
"""
Hook 'n wrap on `starlette.applications`
Our goal is to wrap the __call__ function of the Starlette class
Our goal is to wrap the __call__ function of the Starlette class, so we can create a Context object.
"""
modified_starlette = importhook.copy_module(starlette)
former_call = copy.deepcopy(starlette.Starlette.__call__)
Expand All @@ -28,9 +28,8 @@
try:
if scope["type"] != "http":
return await former_call(app, scope, receive, send)
context1 = Context(req=scope, source="starlette")
context1.set_as_current_context()
request_handler(stage="init")
context = Context(req=scope, source="starlette")
on_init_handler(context)

Check warning on line 32 in aikido_zen/sources/starlette/starlette_applications.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/starlette/starlette_applications.py#L31-L32

Added lines #L31 - L32 were not covered by tests
except Exception as e:
logger.debug("Exception on aikido __call__ function : %s", e)
return await former_call(app, scope, receive, send)
3 changes: 0 additions & 3 deletions aikido_zen/vulnerabilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def run_vulnerability_scan(kind, op, args):
):
# The client turned protection off for this route, not scanning
return
if thread_cache.is_bypassed_ip(context.remote_address):
# This IP is on the bypass list, not scanning
return

error_type = AikidoException # Default error
error_args = tuple()
Expand Down
10 changes: 0 additions & 10 deletions aikido_zen/vulnerabilities/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,6 @@ def test_ssrf(caplog, get_context):
run_vulnerability_scan(kind="ssrf", op="test", args=tuple())


def test_lifecycle_cache_bypassed_ip(caplog, get_context):
get_context.set_as_current_context()
cache = ThreadCache()
cache.config.bypassed_ips = IPList()
cache.config.bypassed_ips.add("198.51.100.23")
assert cache.is_bypassed_ip("198.51.100.23")
run_vulnerability_scan(kind="test", op="test", args=tuple())
assert len(caplog.text) == 0


def test_sql_injection(caplog, get_context, monkeypatch):
get_context.set_as_current_context()
cache = ThreadCache()
Expand Down
Loading