diff --git a/aikido_firewall/__init__.py b/aikido_firewall/__init__.py index 292c8aed9..aa4910279 100644 --- a/aikido_firewall/__init__.py +++ b/aikido_firewall/__init__.py @@ -15,11 +15,13 @@ load_dotenv() -def protect(): +def protect(module="any"): """Start Aikido agent""" # Import sources import aikido_firewall.sources.django - import aikido_firewall.sources.flask + + if module != "django": + import aikido_firewall.sources.flask # Import sinks import aikido_firewall.sinks.pymysql diff --git a/aikido_firewall/context/__init__.py b/aikido_firewall/context/__init__.py index 723b0d1a7..53910f4ab 100644 --- a/aikido_firewall/context/__init__.py +++ b/aikido_firewall/context/__init__.py @@ -4,6 +4,7 @@ import threading +SUPPORTED_SOURCES = ["django", "flask"] local = threading.local() @@ -15,21 +16,45 @@ def get_current_context(): return None +def parse_headers(headers): + """Parse EnvironHeaders object into a dict""" + if isinstance(headers, dict): + return headers + return dict(zip(headers.keys(), headers.values())) + + class Context: """ A context object, it stores everything that is important for vulnerability detection """ - def __init__(self, req): + def __init__(self, req, source): + if not source in SUPPORTED_SOURCES: + raise ValueError(f"Source {source} not supported") + self.source = source self.method = req.method + self.headers = parse_headers(req.headers) + if source == "flask": + self.set_flask_attrs(req) + elif source == "django": + self.set_django_attrs(req) + + def set_django_attrs(self, req): + """set properties that are specific to django""" + self.remote_address = req.META.get("REMOTE_ADDR") + self.url = req.build_absolute_uri() + self.body = dict(req.POST) + self.query = dict(req.GET) + self.cookies = req.COOKIES + + def set_flask_attrs(self, req): + """Set properties that are specific to flask""" self.remote_address = req.remote_addr self.url = req.url - self.body = req.form - self.headers = req.headers - self.query = req.args - self.cookies = req.cookies - self.source = "flask" + self.body = req.form.to_dict() + self.query = req.args.to_dict() + self.cookies = req.cookies.to_dict() def __reduce__(self): return ( diff --git a/aikido_firewall/context/init_test.py b/aikido_firewall/context/init_test.py index 0c247666e..f73c06e33 100644 --- a/aikido_firewall/context/init_test.py +++ b/aikido_firewall/context/init_test.py @@ -2,36 +2,64 @@ from aikido_firewall.context import Context, get_current_context -@pytest.fixture -def sample_request(): - # Mock a sample request object for testing - class Request: - def __init__(self): - self.method = "GET" - self.remote_addr = "127.0.0.1" - self.url = "/test" - self.form = {} - self.headers = {} - self.args = {} - self.cookies = {} - - return Request() - - def test_get_current_context_no_context(): # Test get_current_context() when no context is set assert get_current_context() is None -def test_set_as_current_context(sample_request): +def test_set_as_current_context(mocker): # Test set_as_current_context() method - context = Context(sample_request) + sample_request = mocker.MagicMock() + context = Context(sample_request, "flask") context.set_as_current_context() assert get_current_context() == context -def test_get_current_context_with_context(sample_request): +def test_get_current_context_with_context(mocker): # Test get_current_context() when a context is set - context = Context(sample_request) + sample_request = mocker.MagicMock() + context = Context(sample_request, "flask") context.set_as_current_context() assert get_current_context() == context + + +def test_context_init_flask(mocker): + req = mocker.MagicMock() + req.method = "GET" + req.remote_addr = "127.0.0.1" + req.url = "http://example.com" + req.form.to_dict.return_value = {"key": "value"} + req.headers = {"Content-Type": "application/json"} + req.args.to_dict.return_value = {"key": "value"} + req.cookies.to_dict.return_value = {"cookie": "value"} + + context = Context(req, "flask") + assert context.source == "flask" + assert context.method == "GET" + assert context.remote_address == "127.0.0.1" + assert context.url == "http://example.com" + assert context.body == {"key": "value"} + assert context.headers == {"Content-Type": "application/json"} + assert context.query == {"key": "value"} + assert context.cookies == {"cookie": "value"} + + +def test_context_init_django(mocker): + req = mocker.MagicMock() + req.method = "POST" + req.META.get.return_value = "127.0.0.1" + req.build_absolute_uri.return_value = "http://example.com" + req.POST = {"key": "value"} + req.headers = {"Content-Type": "application/json"} + req.GET = {"key": "value"} + req.COOKIES = {"cookie": "value"} + + context = Context(req, "django") + assert context.source == "django" + assert context.method == "POST" + assert context.remote_address == "127.0.0.1" + assert context.url == "http://example.com" + assert context.body == {"key": "value"} + assert context.headers == {"Content-Type": "application/json"} + assert context.query == {"key": "value"} + assert context.cookies == {"cookie": "value"} diff --git a/aikido_firewall/middleware/django.py b/aikido_firewall/middleware/django.py index c5fe8bf2b..276387b36 100644 --- a/aikido_firewall/middleware/django.py +++ b/aikido_firewall/middleware/django.py @@ -4,6 +4,7 @@ """ from aikido_firewall.helpers.logging import logger +from aikido_firewall.context import Context class AikidoMiddleware: @@ -16,6 +17,8 @@ def __init__(self, get_response): def __call__(self, request, *args, **kwargs): logger.debug("Aikido middleware for `django` was called : __call__") + context = Context(request, "django") + context.set_as_current_context() return self.get_response(request) def process_exception(self, request, exception): diff --git a/aikido_firewall/sources/flask.py b/aikido_firewall/sources/flask.py index 99a477320..d5d92cced 100644 --- a/aikido_firewall/sources/flask.py +++ b/aikido_firewall/sources/flask.py @@ -21,7 +21,7 @@ def __init__(self): def dispatch(self, request, call_next): """Dispatch function""" logger.debug("Aikido middleware for `flask` was called") - context = Context(request) + context = Context(request, "flask") context.set_as_current_context() response = call_next(request) diff --git a/sample-apps/django-mysql/manage.py b/sample-apps/django-mysql/manage.py index a3de10f7d..f66fda4b7 100755 --- a/sample-apps/django-mysql/manage.py +++ b/sample-apps/django-mysql/manage.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Django's command-line utility for administrative tasks.""" import aikido_firewall # Aikido module +aikido_firewall.protect("django") + import os import sys diff --git a/sample-apps/django-mysql/sample_app/templates/app/create_dog.html b/sample-apps/django-mysql/sample_app/templates/app/create_dog.html new file mode 100644 index 000000000..b2375a804 --- /dev/null +++ b/sample-apps/django-mysql/sample_app/templates/app/create_dog.html @@ -0,0 +1,6 @@ +